1use cranelift::prelude::*;
55use cranelift_codegen::ir::{immediates::Offset32, Value};
56use cranelift_module::Module;
57
58use crate::{errors::EquationError, operators};
59
60#[derive(Debug, Clone, PartialEq)]
67pub struct VarRef {
68 pub name: String,
69 pub vec_ref: Value,
70 pub index: u32,
71}
72
73#[derive(Debug, Clone, PartialEq)]
83pub enum Expr {
84 Const(f64),
86 Var(VarRef),
88 Add(Box<Expr>, Box<Expr>),
90 Mul(Box<Expr>, Box<Expr>),
92 Sub(Box<Expr>, Box<Expr>),
94 Div(Box<Expr>, Box<Expr>),
96 Abs(Box<Expr>),
98 Pow(Box<Expr>, i64),
100 PowFloat(Box<Expr>, f64),
102 PowExpr(Box<Expr>, Box<Expr>),
104 Exp(Box<Expr>),
106 Ln(Box<Expr>),
108 Sqrt(Box<Expr>),
110 Sin(Box<Expr>),
112 Cos(Box<Expr>),
114 Neg(Box<Expr>),
116 Cached(Box<Expr>, Option<f64>),
118}
119
/// One instruction of the postfix stack program produced by `Expr::flatten`.
///
/// Operands are taken from an implicit value stack. Binary ops pop the
/// right-hand operand first, then the left-hand one, and push the result.
/// `PartialEq` is derived so programs can be compared (e.g. in tests).
#[derive(Debug, Clone, PartialEq)]
pub enum LinearOp {
    /// Push a constant.
    LoadConst(f64),
    /// Push the `f64` input slot with this index.
    LoadVar(u32),
    /// Pop `rhs`, `lhs`; push `lhs + rhs`.
    Add,
    /// Pop `rhs`, `lhs`; push `lhs - rhs`.
    Sub,
    /// Pop `rhs`, `lhs`; push `lhs * rhs`.
    Mul,
    /// Pop `rhs`, `lhs`; push `lhs / rhs`.
    Div,
    /// Pop `x`; push `|x|`.
    Abs,
    /// Pop `x`; push `-x`.
    Neg,
    /// Pop `base`; push `base^n` for a compile-time integer `n`.
    PowConst(i64),
    /// Pop `base`; push `base^p` for a compile-time float `p`.
    PowFloat(f64),
    /// Pop `exponent`, then `base`; push `base^exponent`.
    PowExpr,
    /// Pop `x`; push `e^x`.
    Exp,
    /// Pop `x`; push `ln(x)`.
    Ln,
    /// Pop `x`; push `sqrt(x)`.
    Sqrt,
    /// Pop `x`; push `sin(x)`.
    Sin,
    /// Pop `x`; push `cos(x)`.
    Cos,
}
156
157#[derive(Debug, Clone)]
159pub struct FlattenedExpr {
160 pub ops: Vec<LinearOp>,
162 pub max_var_index: Option<u32>,
164 pub constant_result: Option<f64>,
166}
167
168impl Expr {
169 pub fn pre_evaluate(
171 &self,
172 var_cache: &mut std::collections::HashMap<String, f64>,
173 ) -> Box<Expr> {
174 match self {
175 Expr::Const(_) => Box::new(self.clone()),
176
177 Expr::Var(var_ref) => {
178 if let Some(&value) = var_cache.get(&var_ref.name) {
180 Box::new(Expr::Const(value))
181 } else {
182 Box::new(self.clone())
183 }
184 }
185
186 Expr::Add(left, right) => {
187 let l = left.pre_evaluate(var_cache);
188 let r = right.pre_evaluate(var_cache);
189 match (&*l, &*r) {
190 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a + b)),
191 _ => Box::new(Expr::Add(l, r)),
192 }
193 }
194
195 Expr::Sub(left, right) => {
196 let l = left.pre_evaluate(var_cache);
197 let r = right.pre_evaluate(var_cache);
198 match (&*l, &*r) {
199 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a - b)),
200 _ => Box::new(Expr::Sub(l, r)),
201 }
202 }
203
204 Expr::Mul(left, right) => {
205 let l = left.pre_evaluate(var_cache);
206 let r = right.pre_evaluate(var_cache);
207 match (&*l, &*r) {
208 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a * b)),
209 _ => Box::new(Expr::Mul(l, r)),
210 }
211 }
212
213 Expr::Div(left, right) => {
214 let l = left.pre_evaluate(var_cache);
215 let r = right.pre_evaluate(var_cache);
216 match (&*l, &*r) {
217 (Expr::Const(a), Expr::Const(b)) if *b != 0.0 => Box::new(Expr::Const(a / b)),
218 _ => Box::new(Expr::Div(l, r)),
219 }
220 }
221
222 Expr::Abs(expr) => {
223 let e = expr.pre_evaluate(var_cache);
224 match &*e {
225 Expr::Const(a) => Box::new(Expr::Const(a.abs())),
226 _ => Box::new(Expr::Abs(e)),
227 }
228 }
229
230 Expr::Neg(expr) => {
231 let e = expr.pre_evaluate(var_cache);
232 match &*e {
233 Expr::Const(a) => Box::new(Expr::Const(-a)),
234 _ => Box::new(Expr::Neg(e)),
235 }
236 }
237
238 Expr::Pow(base, exp) => {
239 let b = base.pre_evaluate(var_cache);
240 match &*b {
241 Expr::Const(a) => Box::new(Expr::Const(a.powi(*exp as i32))),
242 _ => Box::new(Expr::Pow(b, *exp)),
243 }
244 }
245
246 Expr::PowFloat(base, exp) => {
247 let b = base.pre_evaluate(var_cache);
248 match &*b {
249 Expr::Const(a) => Box::new(Expr::Const(a.powf(*exp))),
250 _ => Box::new(Expr::PowFloat(b, *exp)),
251 }
252 }
253
254 Expr::PowExpr(base, exponent) => {
255 let b = base.pre_evaluate(var_cache);
256 let e = exponent.pre_evaluate(var_cache);
257 match (&*b, &*e) {
258 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a.powf(*b))),
259 _ => Box::new(Expr::PowExpr(b, e)),
260 }
261 }
262
263 Expr::Exp(expr) => {
264 let e = expr.pre_evaluate(var_cache);
265 match &*e {
266 Expr::Const(a) => Box::new(Expr::Const(a.exp())),
267 _ => Box::new(Expr::Exp(e)),
268 }
269 }
270
271 Expr::Ln(expr) => {
272 let e = expr.pre_evaluate(var_cache);
273 match &*e {
274 Expr::Const(a) if *a > 0.0 => Box::new(Expr::Const(a.ln())),
275 _ => Box::new(Expr::Ln(e)),
276 }
277 }
278
279 Expr::Sqrt(expr) => {
280 let e = expr.pre_evaluate(var_cache);
281 match &*e {
282 Expr::Const(a) if *a >= 0.0 => Box::new(Expr::Const(a.sqrt())),
283 _ => Box::new(Expr::Sqrt(e)),
284 }
285 }
286
287 Expr::Sin(expr) => {
288 let e = expr.pre_evaluate(var_cache);
289 match &*e {
290 Expr::Const(a) => Box::new(Expr::Const(a.sin())),
291 _ => Box::new(Expr::Sin(e)),
292 }
293 }
294
295 Expr::Cos(expr) => {
296 let e = expr.pre_evaluate(var_cache);
297 match &*e {
298 Expr::Const(a) => Box::new(Expr::Const(a.cos())),
299 _ => Box::new(Expr::Cos(e)),
300 }
301 }
302
303 Expr::Cached(expr, _) => expr.pre_evaluate(var_cache),
304 }
305 }
306
307 pub fn derivative(&self, with_respect_to: &str) -> Box<Expr> {
330 match self {
331 Expr::Const(_) => Box::new(Expr::Const(0.0)),
332
333 Expr::Var(var_ref) => {
334 if var_ref.name == with_respect_to {
335 Box::new(Expr::Const(1.0))
336 } else {
337 Box::new(Expr::Const(0.0))
338 }
339 }
340
341 Expr::Add(left, right) => {
342 Box::new(Expr::Add(
344 left.derivative(with_respect_to),
345 right.derivative(with_respect_to),
346 ))
347 }
348
349 Expr::Sub(left, right) => {
350 Box::new(Expr::Sub(
352 left.derivative(with_respect_to),
353 right.derivative(with_respect_to),
354 ))
355 }
356
357 Expr::Mul(left, right) => {
358 Box::new(Expr::Add(
360 Box::new(Expr::Mul(left.clone(), right.derivative(with_respect_to))),
361 Box::new(Expr::Mul(right.clone(), left.derivative(with_respect_to))),
362 ))
363 }
364
365 Expr::Div(left, right) => {
366 Box::new(Expr::Div(
368 Box::new(Expr::Sub(
369 Box::new(Expr::Mul(right.clone(), left.derivative(with_respect_to))),
370 Box::new(Expr::Mul(left.clone(), right.derivative(with_respect_to))),
371 )),
372 Box::new(Expr::Pow(right.clone(), 2)),
373 ))
374 }
375
376 Expr::Abs(expr) => {
377 Box::new(Expr::Mul(
379 Box::new(Expr::Div(expr.clone(), Box::new(Expr::Abs(expr.clone())))),
380 expr.derivative(with_respect_to),
381 ))
382 }
383
384 Expr::Pow(base, exp) => {
385 Box::new(Expr::Mul(
387 Box::new(Expr::Mul(
388 Box::new(Expr::Const(*exp as f64)),
389 Box::new(Expr::Pow(base.clone(), exp - 1)),
390 )),
391 base.derivative(with_respect_to),
392 ))
393 }
394
395 Expr::PowFloat(base, exp) => {
396 Box::new(Expr::Mul(
398 Box::new(Expr::Mul(
399 Box::new(Expr::Const(*exp)),
400 Box::new(Expr::PowFloat(base.clone(), exp - 1.0)),
401 )),
402 base.derivative(with_respect_to),
403 ))
404 }
405
406 Expr::PowExpr(base, exponent) => {
407 Box::new(Expr::Mul(
410 Box::new(Expr::PowExpr(base.clone(), exponent.clone())),
411 Box::new(Expr::Add(
412 Box::new(Expr::Mul(
413 exponent.derivative(with_respect_to),
414 Box::new(Expr::Ln(base.clone())),
415 )),
416 Box::new(Expr::Mul(
417 exponent.clone(),
418 Box::new(Expr::Div(base.derivative(with_respect_to), base.clone())),
419 )),
420 )),
421 ))
422 }
423
424 Expr::Exp(expr) => {
425 Box::new(Expr::Mul(
427 Box::new(Expr::Exp(expr.clone())),
428 expr.derivative(with_respect_to),
429 ))
430 }
431
432 Expr::Ln(expr) => {
433 Box::new(Expr::Mul(
435 Box::new(Expr::Div(Box::new(Expr::Const(1.0)), expr.clone())),
436 expr.derivative(with_respect_to),
437 ))
438 }
439
440 Expr::Sqrt(expr) => {
441 Box::new(Expr::Mul(
443 Box::new(Expr::Div(
444 Box::new(Expr::Const(1.0)),
445 Box::new(Expr::Sqrt(expr.clone())),
446 )),
447 expr.derivative(with_respect_to),
448 ))
449 }
450
451 Expr::Sin(expr) => {
452 Box::new(Expr::Mul(
454 Box::new(Expr::Cos(expr.clone())),
455 expr.derivative(with_respect_to),
456 ))
457 }
458
459 Expr::Cos(expr) => {
460 Box::new(Expr::Mul(
462 Box::new(Expr::Neg(Box::new(Expr::Sin(expr.clone())))),
463 expr.derivative(with_respect_to),
464 ))
465 }
466
467 Expr::Neg(expr) => {
468 Box::new(Expr::Neg(expr.derivative(with_respect_to)))
470 }
471
472 Expr::Cached(expr, _) => expr.derivative(with_respect_to),
473 }
474 }
475
476 pub fn simplify(&self) -> Box<Expr> {
507 match self {
508 Expr::Const(_) | Expr::Var(_) => Box::new(self.clone()),
510
511 Expr::Add(left, right) => {
512 let l = left.simplify();
513 let r = right.simplify();
514 match (&*l, &*r) {
515 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a + b)),
517 (expr, Expr::Const(0.0)) | (Expr::Const(0.0), expr) => Box::new(expr.clone()),
519 (Expr::Mul(a1, x1), Expr::Mul(a2, x2)) if x1 == x2 => {
521 let combined_coeff = Expr::Add(a1.clone(), a2.clone()).simplify();
522 Box::new(Expr::Mul(combined_coeff, x1.clone()))
523 }
524 (Expr::Add(x, c1), c2)
526 if matches!(**c1, Expr::Const(_)) && matches!(*c2, Expr::Const(_)) =>
527 {
528 Box::new(Expr::Add(
529 x.clone(),
530 Expr::Add(c1.clone(), Box::new(c2.clone())).simplify(),
531 ))
532 }
533 _ => Box::new(Expr::Add(l, r)),
534 }
535 }
536
537 Expr::Sub(left, right) => {
538 let l = left.simplify();
539 let r = right.simplify();
540 match (&*l, &*r) {
541 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a - b)),
543 (expr, Expr::Const(0.0)) => Box::new(expr.clone()),
545 (a, b) if a == b => Box::new(Expr::Const(0.0)),
547 (Expr::Mul(a1, x1), Expr::Mul(a2, x2)) if x1 == x2 => {
549 let combined_coeff = Expr::Sub(a1.clone(), a2.clone()).simplify();
550 Box::new(Expr::Mul(combined_coeff, x1.clone()))
551 }
552 (x, Expr::Const(c)) => {
554 Box::new(Expr::Add(Box::new(x.clone()), Box::new(Expr::Const(-c))))
555 }
556 _ => Box::new(Expr::Sub(l, r)),
557 }
558 }
559
560 Expr::Mul(left, right) => {
561 let l = left.simplify();
562 let r = right.simplify();
563
564 if l == r {
566 return Box::new(Expr::Pow(l, 2)); }
568
569 match (&*l, &*r) {
570 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a * b)),
572 (Expr::Const(0.0), _) | (_, Expr::Const(0.0)) => Box::new(Expr::Const(0.0)),
574 (expr, Expr::Const(1.0)) | (Expr::Const(1.0), expr) => Box::new(expr.clone()),
576 (expr, Expr::Const(-1.0)) | (Expr::Const(-1.0), expr) => {
578 Box::new(Expr::Neg(Box::new(expr.clone())))
579 }
580 (Expr::Pow(b1, e1), Expr::Pow(b2, e2)) if b1 == b2 => {
582 Box::new(Expr::Pow(b1.clone(), e1 + e2))
583 }
584 (Expr::Const(c), Expr::Add(x, y)) | (Expr::Add(x, y), Expr::Const(c))
586 if c.abs() < 10.0 =>
587 {
588 Expr::Add(
589 Box::new(Expr::Mul(Box::new(Expr::Const(*c)), x.clone())),
590 Box::new(Expr::Mul(Box::new(Expr::Const(*c)), y.clone())),
591 )
592 .simplify()
593 }
594 (expr, Expr::Const(2.0)) | (Expr::Const(2.0), expr) => {
596 Box::new(Expr::Add(Box::new(expr.clone()), Box::new(expr.clone())))
597 }
598 (Expr::Mul(c1, x), c2)
600 if matches!(**c1, Expr::Const(_)) && matches!(*c2, Expr::Const(_)) =>
601 {
602 Box::new(Expr::Mul(
603 Expr::Mul(c1.clone(), Box::new(c2.clone())).simplify(),
604 x.clone(),
605 ))
606 }
607 _ => Box::new(Expr::Mul(l, r)),
608 }
609 }
610
611 Expr::Div(left, right) => {
612 let l = left.simplify();
613 let r = right.simplify();
614 match (&*l, &*r) {
615 (Expr::Const(a), Expr::Const(b)) if *b != 0.0 => Box::new(Expr::Const(a / b)),
617 (Expr::Const(0.0), _) => Box::new(Expr::Const(0.0)),
619 (expr, Expr::Const(1.0)) => Box::new(expr.clone()),
621 (expr, Expr::Const(-1.0)) => Box::new(Expr::Neg(Box::new(expr.clone()))),
623 (a, b) if a == b => Box::new(Expr::Const(1.0)),
625 (Expr::Pow(b1, e1), Expr::Pow(b2, e2)) if b1 == b2 => {
627 Box::new(Expr::Pow(b1.clone(), e1 - e2))
628 }
629 (x, Expr::Const(c)) if *c != 0.0 && c.abs() > 1e-10 => Box::new(Expr::Mul(
631 Box::new(x.clone()),
632 Box::new(Expr::Const(1.0 / c)),
633 )),
634 (Expr::Div(x, y), z) => Box::new(Expr::Div(
636 x.clone(),
637 Box::new(Expr::Mul(y.clone(), Box::new(z.clone()))),
638 )),
639 _ => Box::new(Expr::Div(l, r)),
640 }
641 }
642
643 Expr::Abs(expr) => {
644 let e = expr.simplify();
645 match &*e {
646 Expr::Const(a) => Box::new(Expr::Const(a.abs())),
648 Expr::Abs(inner) => Box::new(Expr::Abs(inner.clone())),
650 Expr::Neg(inner) => Box::new(Expr::Abs(inner.clone())),
652 Expr::Pow(_, exp) if exp % 2 == 0 => e,
654 _ => Box::new(Expr::Abs(e)),
655 }
656 }
657
658 Expr::Pow(base, exp) => {
659 let b = base.simplify();
660 match (&*b, exp) {
661 (_, 0) => Box::new(Expr::Const(1.0)),
663 (Expr::Const(a), exp) => Box::new(Expr::Const(a.powi(*exp as i32))),
665 (expr, 1) => Box::new(expr.clone()),
667 (expr, exp) if *exp < 0 => Box::new(Expr::Div(
669 Box::new(Expr::Const(1.0)),
670 Box::new(Expr::Pow(Box::new(expr.clone()), -exp)),
671 )),
672 (Expr::Pow(inner_base, inner_exp), outer_exp) => {
674 Box::new(Expr::Pow(inner_base.clone(), inner_exp * outer_exp))
675 }
676 (Expr::Mul(x, y), n) if *n >= 2 && *n <= 4 => Box::new(Expr::Mul(
678 Box::new(Expr::Pow(x.clone(), *n)),
679 Box::new(Expr::Pow(y.clone(), *n)),
680 )),
681 _ => Box::new(Expr::Pow(b, *exp)),
682 }
683 }
684
685 Expr::PowFloat(base, exp) => {
686 let b = base.simplify();
687 match (&*b, exp) {
688 (_, exp) if exp.abs() < 1e-10 => Box::new(Expr::Const(1.0)),
690 (Expr::Const(a), exp) => Box::new(Expr::Const(a.powf(*exp))),
692 (expr, exp) if (exp - 1.0).abs() < 1e-10 => Box::new(expr.clone()),
694 (expr, exp) if exp.fract().abs() < 1e-10 => {
696 Box::new(Expr::Pow(Box::new(expr.clone()), *exp as i64))
697 }
698 _ => Box::new(Expr::PowFloat(b, *exp)),
699 }
700 }
701
702 Expr::PowExpr(base, exponent) => {
703 let b = base.simplify();
704 let e = exponent.simplify();
705 match (&*b, &*e) {
706 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a.powf(*b))),
708 (_, Expr::Const(0.0)) => Box::new(Expr::Const(1.0)),
710 (expr, Expr::Const(1.0)) => Box::new(expr.clone()),
712 (expr, Expr::Const(exp)) if exp.fract().abs() < 1e-10 => {
714 Box::new(Expr::Pow(Box::new(expr.clone()), *exp as i64))
715 }
716 (expr, Expr::Const(exp)) => {
717 Box::new(Expr::PowFloat(Box::new(expr.clone()), *exp))
718 }
719 _ => Box::new(Expr::PowExpr(b, e)),
720 }
721 }
722
723 Expr::Exp(expr) => {
724 let e = expr.simplify();
725 match &*e {
726 Expr::Const(0.0) => Box::new(Expr::Const(1.0)),
728 Expr::Const(a) => Box::new(Expr::Const(a.exp())),
730 Expr::Ln(inner) => inner.clone(),
732 Expr::Add(x, y) => Box::new(Expr::Mul(
734 Box::new(Expr::Exp(x.clone())),
735 Box::new(Expr::Exp(y.clone())),
736 )),
737 _ => Box::new(Expr::Exp(e)),
738 }
739 }
740
741 Expr::Ln(expr) => {
742 let e = expr.simplify();
743 match &*e {
744 Expr::Const(a) if *a > 0.0 => Box::new(Expr::Const(a.ln())),
746 Expr::Const(1.0) => Box::new(Expr::Const(0.0)),
748 Expr::Exp(inner) => inner.clone(),
750 Expr::Mul(x, y) => Box::new(Expr::Add(
752 Box::new(Expr::Ln(x.clone())),
753 Box::new(Expr::Ln(y.clone())),
754 )),
755 Expr::Div(x, y) => Box::new(Expr::Sub(
757 Box::new(Expr::Ln(x.clone())),
758 Box::new(Expr::Ln(y.clone())),
759 )),
760 Expr::Pow(x, n) => Box::new(Expr::Mul(
762 Box::new(Expr::Const(*n as f64)),
763 Box::new(Expr::Ln(x.clone())),
764 )),
765 _ => Box::new(Expr::Ln(e)),
766 }
767 }
768
769 Expr::Sqrt(expr) => {
770 let e = expr.simplify();
771 match &*e {
772 Expr::Const(a) if *a >= 0.0 => Box::new(Expr::Const(a.sqrt())),
774 Expr::Const(0.0) => Box::new(Expr::Const(0.0)),
776 Expr::Const(1.0) => Box::new(Expr::Const(1.0)),
778 Expr::Pow(x, 2) => Box::new(Expr::Abs(x.clone())),
780 Expr::Mul(x, y) => Box::new(Expr::Mul(
782 Box::new(Expr::Sqrt(x.clone())),
783 Box::new(Expr::Sqrt(y.clone())),
784 )),
785 _ => Box::new(Expr::Sqrt(e)),
786 }
787 }
788
789 Expr::Sin(expr) => {
790 let e = expr.simplify();
791 match &*e {
792 Expr::Const(0.0) => Box::new(Expr::Const(0.0)),
794 Expr::Const(a) => Box::new(Expr::Const(a.sin())),
796 _ => Box::new(Expr::Sin(e)),
797 }
798 }
799
800 Expr::Cos(expr) => {
801 let e = expr.simplify();
802 match &*e {
803 Expr::Const(0.0) => Box::new(Expr::Const(1.0)),
805 Expr::Const(a) => Box::new(Expr::Const(a.cos())),
807 _ => Box::new(Expr::Cos(e)),
808 }
809 }
810
811 Expr::Neg(expr) => {
812 let e = expr.simplify();
813 match &*e {
814 Expr::Const(a) => Box::new(Expr::Const(-a)),
816 Expr::Neg(inner) => inner.clone(),
818 Expr::Add(x, y) => {
820 Expr::Sub(Box::new(Expr::Neg(x.clone())), y.clone()).simplify()
821 }
822 Expr::Sub(x, y) => {
824 Expr::Add(Box::new(Expr::Neg(x.clone())), y.clone()).simplify()
825 }
826 Expr::Mul(c, x) if matches!(**c, Expr::Const(_)) => {
828 Expr::Mul(Box::new(Expr::Neg(c.clone())), x.clone()).simplify()
829 }
830 _ => Box::new(Expr::Neg(e)),
831 }
832 }
833
834 Expr::Cached(expr, cached_value) => {
835 if cached_value.is_some() {
836 Box::new(self.clone())
837 } else {
838 expr.simplify()
840 }
841 }
842 }
843 }
844
845 pub fn insert<F>(&self, predicate: F, replacement: &Expr) -> Box<Expr>
858 where
859 F: Fn(&Expr) -> bool + Clone,
860 {
861 if predicate(self) {
862 Box::new(replacement.clone())
863 } else {
864 match self {
865 Expr::Const(_) | Expr::Var(_) => Box::new(self.clone()),
866 Expr::Add(left, right) => Box::new(Expr::Add(
867 left.insert(predicate.clone(), replacement),
868 right.insert(predicate, replacement),
869 )),
870 Expr::Mul(left, right) => Box::new(Expr::Mul(
871 left.insert(predicate.clone(), replacement),
872 right.insert(predicate, replacement),
873 )),
874 Expr::Sub(left, right) => Box::new(Expr::Sub(
875 left.insert(predicate.clone(), replacement),
876 right.insert(predicate, replacement),
877 )),
878 Expr::Div(left, right) => Box::new(Expr::Div(
879 left.insert(predicate.clone(), replacement),
880 right.insert(predicate, replacement),
881 )),
882 Expr::Abs(expr) => Box::new(Expr::Abs(expr.insert(predicate, replacement))),
883 Expr::Pow(base, exp) => {
884 Box::new(Expr::Pow(base.insert(predicate, replacement), *exp))
885 }
886 Expr::PowFloat(base, exp) => {
887 Box::new(Expr::PowFloat(base.insert(predicate, replacement), *exp))
888 }
889 Expr::PowExpr(base, exponent) => Box::new(Expr::PowExpr(
890 base.insert(predicate.clone(), replacement),
891 exponent.insert(predicate, replacement),
892 )),
893 Expr::Exp(expr) => Box::new(Expr::Exp(expr.insert(predicate, replacement))),
894 Expr::Ln(expr) => Box::new(Expr::Ln(expr.insert(predicate, replacement))),
895 Expr::Sqrt(expr) => Box::new(Expr::Sqrt(expr.insert(predicate, replacement))),
896 Expr::Sin(expr) => Box::new(Expr::Sin(expr.insert(predicate, replacement))),
897 Expr::Cos(expr) => Box::new(Expr::Cos(expr.insert(predicate, replacement))),
898 Expr::Neg(expr) => Box::new(Expr::Neg(expr.insert(predicate, replacement))),
899 Expr::Cached(expr, _) => {
900 Box::new(Expr::Cached(expr.insert(predicate, replacement), None))
901 }
902 }
903 }
904 }
905
906 pub fn flatten(&self) -> FlattenedExpr {
917 let mut ops = Vec::new();
918 let mut max_var_index = None;
919
920 if let Some(constant) = self.try_evaluate_constant() {
922 return FlattenedExpr {
923 ops: vec![LinearOp::LoadConst(constant)],
924 max_var_index: None,
925 constant_result: Some(constant),
926 };
927 }
928
929 self.flatten_recursive(&mut ops, &mut max_var_index);
930
931 FlattenedExpr {
932 ops,
933 max_var_index,
934 constant_result: None,
935 }
936 }
937
938 fn try_evaluate_constant(&self) -> Option<f64> {
940 match self {
941 Expr::Const(val) => Some(*val),
942 Expr::Var(_) => None,
943 Expr::Add(left, right) => {
944 Some(left.try_evaluate_constant()? + right.try_evaluate_constant()?)
945 }
946 Expr::Sub(left, right) => {
947 Some(left.try_evaluate_constant()? - right.try_evaluate_constant()?)
948 }
949 Expr::Mul(left, right) => {
950 Some(left.try_evaluate_constant()? * right.try_evaluate_constant()?)
951 }
952 Expr::Div(left, right) => {
953 let r = right.try_evaluate_constant()?;
954 if r.abs() < 1e-10 {
955 return None;
956 }
957 Some(left.try_evaluate_constant()? / r)
958 }
959 Expr::Abs(expr) => Some(expr.try_evaluate_constant()?.abs()),
960 Expr::Neg(expr) => Some(-expr.try_evaluate_constant()?),
961 Expr::Pow(base, exp) => Some(base.try_evaluate_constant()?.powi(*exp as i32)),
962 Expr::PowFloat(base, exp) => Some(base.try_evaluate_constant()?.powf(*exp)),
963 Expr::PowExpr(base, exponent) => Some(
964 base.try_evaluate_constant()?
965 .powf(exponent.try_evaluate_constant()?),
966 ),
967 Expr::Exp(expr) => Some(expr.try_evaluate_constant()?.exp()),
968 Expr::Ln(expr) => {
969 let val = expr.try_evaluate_constant()?;
970 if val <= 0.0 {
971 return None;
972 }
973 Some(val.ln())
974 }
975 Expr::Sqrt(expr) => {
976 let val = expr.try_evaluate_constant()?;
977 if val < 0.0 {
978 return None;
979 }
980 Some(val.sqrt())
981 }
982 Expr::Sin(expr) => Some(expr.try_evaluate_constant()?.sin()),
983 Expr::Cos(expr) => Some(expr.try_evaluate_constant()?.cos()),
984 Expr::Cached(expr, cached_value) => {
985 cached_value.or_else(|| expr.try_evaluate_constant())
986 }
987 }
988 }
989
990 fn flatten_recursive(&self, ops: &mut Vec<LinearOp>, max_var_index: &mut Option<u32>) {
992 match self {
993 Expr::Const(val) => {
994 ops.push(LinearOp::LoadConst(*val));
995 }
996
997 Expr::Var(var_ref) => {
998 let index = var_ref.index;
999 *max_var_index = Some(max_var_index.unwrap_or(0).max(index));
1000 ops.push(LinearOp::LoadVar(index));
1001 }
1002
1003 Expr::Add(left, right) => {
1004 left.flatten_recursive(ops, max_var_index);
1005 right.flatten_recursive(ops, max_var_index);
1006 ops.push(LinearOp::Add);
1007 }
1008
1009 Expr::Sub(left, right) => {
1010 left.flatten_recursive(ops, max_var_index);
1011 right.flatten_recursive(ops, max_var_index);
1012 ops.push(LinearOp::Sub);
1013 }
1014
1015 Expr::Mul(left, right) => {
1016 left.flatten_recursive(ops, max_var_index);
1017 right.flatten_recursive(ops, max_var_index);
1018 ops.push(LinearOp::Mul);
1019 }
1020
1021 Expr::Div(left, right) => {
1022 left.flatten_recursive(ops, max_var_index);
1023 right.flatten_recursive(ops, max_var_index);
1024 ops.push(LinearOp::Div);
1025 }
1026
1027 Expr::Abs(expr) => {
1028 expr.flatten_recursive(ops, max_var_index);
1029 ops.push(LinearOp::Abs);
1030 }
1031
1032 Expr::Neg(expr) => {
1033 expr.flatten_recursive(ops, max_var_index);
1034 ops.push(LinearOp::Neg);
1035 }
1036
1037 Expr::Pow(base, exp) => {
1038 base.flatten_recursive(ops, max_var_index);
1039 ops.push(LinearOp::PowConst(*exp));
1040 }
1041
1042 Expr::PowFloat(base, exp) => {
1043 base.flatten_recursive(ops, max_var_index);
1044 ops.push(LinearOp::PowFloat(*exp));
1045 }
1046
1047 Expr::PowExpr(base, exponent) => {
1048 base.flatten_recursive(ops, max_var_index);
1049 exponent.flatten_recursive(ops, max_var_index);
1050 ops.push(LinearOp::PowExpr);
1051 }
1052
1053 Expr::Exp(expr) => {
1054 expr.flatten_recursive(ops, max_var_index);
1055 ops.push(LinearOp::Exp);
1056 }
1057
1058 Expr::Ln(expr) => {
1059 expr.flatten_recursive(ops, max_var_index);
1060 ops.push(LinearOp::Ln);
1061 }
1062
1063 Expr::Sqrt(expr) => {
1064 expr.flatten_recursive(ops, max_var_index);
1065 ops.push(LinearOp::Sqrt);
1066 }
1067
1068 Expr::Sin(expr) => {
1069 expr.flatten_recursive(ops, max_var_index);
1070 ops.push(LinearOp::Sin);
1071 }
1072
1073 Expr::Cos(expr) => {
1074 expr.flatten_recursive(ops, max_var_index);
1075 ops.push(LinearOp::Cos);
1076 }
1077
1078 Expr::Cached(expr, cached_value) => {
1079 if let Some(val) = cached_value {
1080 ops.push(LinearOp::LoadConst(*val));
1081 } else {
1082 expr.flatten_recursive(ops, max_var_index);
1083 }
1084 }
1085 }
1086 }
1087
1088 pub fn codegen_flattened(
1093 &self,
1094 builder: &mut FunctionBuilder,
1095 module: &mut dyn Module,
1096 ) -> Result<Value, EquationError> {
1097 let flattened = self.flatten();
1098
1099 if let Some(constant) = flattened.constant_result {
1101 return Ok(builder.ins().f64const(constant));
1102 }
1103
1104 let mut value_stack = Vec::with_capacity(flattened.ops.len());
1106
1107 let input_ptr = builder
1109 .func
1110 .dfg
1111 .block_params(builder.current_block().unwrap())[0];
1112
1113 for op in &flattened.ops {
1115 match op {
1116 LinearOp::LoadConst(val) => {
1117 value_stack.push(builder.ins().f64const(*val));
1118 }
1119
1120 LinearOp::LoadVar(index) => {
1121 let offset = (*index as i32) * 8;
1122 let memflags = MemFlags::new().with_aligned().with_readonly().with_notrap();
1123 let val =
1124 builder
1125 .ins()
1126 .load(types::F64, memflags, input_ptr, Offset32::new(offset));
1127 value_stack.push(val);
1128 }
1129
1130 LinearOp::Add => {
1131 let rhs = value_stack.pop().unwrap();
1132 let lhs = value_stack.pop().unwrap();
1133 value_stack.push(builder.ins().fadd(lhs, rhs));
1134 }
1135
1136 LinearOp::Sub => {
1137 let rhs = value_stack.pop().unwrap();
1138 let lhs = value_stack.pop().unwrap();
1139 value_stack.push(builder.ins().fsub(lhs, rhs));
1140 }
1141
1142 LinearOp::Mul => {
1143 let rhs = value_stack.pop().unwrap();
1144 let lhs = value_stack.pop().unwrap();
1145 value_stack.push(builder.ins().fmul(lhs, rhs));
1146 }
1147
1148 LinearOp::Div => {
1149 let rhs = value_stack.pop().unwrap();
1150 let lhs = value_stack.pop().unwrap();
1151 value_stack.push(builder.ins().fdiv(lhs, rhs));
1152 }
1153
1154 LinearOp::Abs => {
1155 let val = value_stack.pop().unwrap();
1156 value_stack.push(builder.ins().fabs(val));
1157 }
1158
1159 LinearOp::Neg => {
1160 let val = value_stack.pop().unwrap();
1161 value_stack.push(builder.ins().fneg(val));
1162 }
1163
1164 LinearOp::PowConst(exp) => {
1165 let base = value_stack.pop().unwrap();
1166 let result = match *exp {
1167 0 => builder.ins().f64const(1.0),
1168 1 => base,
1169 2 => builder.ins().fmul(base, base),
1170 3 => {
1171 let square = builder.ins().fmul(base, base);
1172 builder.ins().fmul(square, base)
1173 }
1174 4 => {
1175 let square = builder.ins().fmul(base, base);
1176 builder.ins().fmul(square, square)
1177 }
1178 -1 => {
1179 let one = builder.ins().f64const(1.0);
1180 builder.ins().fdiv(one, base)
1181 }
1182 -2 => {
1183 let square = builder.ins().fmul(base, base);
1184 let one = builder.ins().f64const(1.0);
1185 builder.ins().fdiv(one, square)
1186 }
1187 _ => {
1188 generate_optimized_power(builder, base, *exp)
1190 }
1191 };
1192 value_stack.push(result);
1193 }
1194
1195 LinearOp::PowFloat(exp) => {
1196 let base = value_stack.pop().unwrap();
1197 let func_id = crate::operators::pow::link_powf(module).unwrap();
1198 let exp_val = builder.ins().f64const(*exp);
1199 let result =
1200 crate::operators::pow::call_powf(builder, module, func_id, base, exp_val);
1201 value_stack.push(result);
1202 }
1203
1204 LinearOp::PowExpr => {
1205 let exponent = value_stack.pop().unwrap();
1206 let base = value_stack.pop().unwrap();
1207 let func_id = crate::operators::pow::link_powf(module).unwrap();
1208 let result =
1209 crate::operators::pow::call_powf(builder, module, func_id, base, exponent);
1210 value_stack.push(result);
1211 }
1212
1213 LinearOp::Exp => {
1214 let arg = value_stack.pop().unwrap();
1215 let func_id = operators::exp::link_exp(module).unwrap();
1216 let result = operators::exp::call_exp(builder, module, func_id, arg);
1217 value_stack.push(result);
1218 }
1219
1220 LinearOp::Ln => {
1221 let arg = value_stack.pop().unwrap();
1222 let func_id = operators::ln::link_ln(module).unwrap();
1223 let result = operators::ln::call_ln(builder, module, func_id, arg);
1224 value_stack.push(result);
1225 }
1226
1227 LinearOp::Sqrt => {
1228 let arg = value_stack.pop().unwrap();
1229 let func_id = operators::sqrt::link_sqrt(module).unwrap();
1230 let result = operators::sqrt::call_sqrt(builder, module, func_id, arg);
1231 value_stack.push(result);
1232 }
1233
1234 LinearOp::Sin => {
1235 let arg = value_stack.pop().unwrap();
1236 let func_id = crate::operators::trigonometric::link_sin(module).unwrap();
1237 let result =
1238 crate::operators::trigonometric::call_sin(builder, module, func_id, arg);
1239 value_stack.push(result);
1240 }
1241
1242 LinearOp::Cos => {
1243 let arg = value_stack.pop().unwrap();
1244 let func_id = crate::operators::trigonometric::link_cos(module).unwrap();
1245 let result =
1246 crate::operators::trigonometric::call_cos(builder, module, func_id, arg);
1247 value_stack.push(result);
1248 }
1249 }
1250 }
1251
1252 Ok(value_stack.pop().unwrap())
1254 }
1255}
1256
1257fn generate_optimized_power(builder: &mut FunctionBuilder, base: Value, exp: i64) -> Value {
1259 if exp == 0 {
1260 return builder.ins().f64const(1.0);
1261 }
1262
1263 if exp == 1 {
1264 return base;
1265 }
1266
1267 let abs_exp = exp.abs();
1268 let mut result = builder.ins().f64const(1.0);
1269 let mut current_base = base;
1270 let mut remaining = abs_exp;
1271
1272 while remaining > 0 {
1274 if remaining & 1 == 1 {
1275 result = builder.ins().fmul(result, current_base);
1276 }
1277 if remaining > 1 {
1278 current_base = builder.ins().fmul(current_base, current_base);
1279 }
1280 remaining >>= 1;
1281 }
1282
1283 if exp < 0 {
1284 let one = builder.ins().f64const(1.0);
1285 builder.ins().fdiv(one, result)
1286 } else {
1287 result
1288 }
1289}
1290
1291impl std::fmt::Display for Expr {
1303 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1304 match self {
1305 Expr::Const(val) => write!(f, "{val}"),
1306 Expr::Var(var_ref) => write!(f, "{0}", var_ref.name),
1307 Expr::Add(left, right) => write!(f, "({left} + {right})"),
1308 Expr::Mul(left, right) => write!(f, "({left} * {right})"),
1309 Expr::Sub(left, right) => write!(f, "({left} - {right})"),
1310 Expr::Div(left, right) => write!(f, "({left} / {right})"),
1311 Expr::Abs(expr) => write!(f, "|{expr}|"),
1312 Expr::Pow(base, exp) => write!(f, "({base}^{exp})"),
1313 Expr::PowFloat(base, exp) => write!(f, "({base}^{exp})"),
1314 Expr::PowExpr(base, exponent) => write!(f, "({base}^{exponent})"),
1315 Expr::Exp(expr) => write!(f, "exp({expr})"),
1316 Expr::Ln(expr) => write!(f, "ln({expr})"),
1317 Expr::Sqrt(expr) => write!(f, "sqrt({expr})"),
1318 Expr::Sin(expr) => write!(f, "sin({expr})"),
1319 Expr::Cos(expr) => write!(f, "cos({expr})"),
1320 Expr::Neg(expr) => write!(f, "-({expr})"),
1321 Expr::Cached(expr, _) => write!(f, "{expr}"),
1322 }
1323 }
1324}
1325
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable node named `name` (slot 0, dummy vector reference).
    fn var(name: &str) -> Box<Expr> {
        Box::new(Expr::Var(VarRef {
            name: name.to_string(),
            vec_ref: Value::from_u32(0),
            index: 0,
        }))
    }

    #[test]
    fn test_simplify() {
        // Constant folding: 2 + 3 = 5
        assert_eq!(
            *Expr::Add(Box::new(Expr::Const(2.0)), Box::new(Expr::Const(3.0))).simplify(),
            Expr::Const(5.0)
        );

        // Additive identity: x + 0 = x
        assert_eq!(
            *Expr::Add(var("x"), Box::new(Expr::Const(0.0))).simplify(),
            *var("x")
        );

        // Multiplicative identity: x * 1 = x
        assert_eq!(
            *Expr::Mul(var("x"), Box::new(Expr::Const(1.0))).simplify(),
            *var("x")
        );

        // Multiplication by zero: x * 0 = 0
        assert_eq!(
            *Expr::Mul(var("x"), Box::new(Expr::Const(0.0))).simplify(),
            Expr::Const(0.0)
        );

        // Division by one: x / 1 = x
        assert_eq!(
            *Expr::Div(var("x"), Box::new(Expr::Const(1.0))).simplify(),
            *var("x")
        );

        // Self-division: x / x = 1
        assert_eq!(*Expr::Div(var("x"), var("x")).simplify(), Expr::Const(1.0));

        // Trivial powers: x^0 = 1, x^1 = x
        assert_eq!(*Expr::Pow(var("x"), 0).simplify(), Expr::Const(1.0));
        assert_eq!(*Expr::Pow(var("x"), 1).simplify(), *var("x"));

        // Absolute value of a constant
        assert_eq!(
            *Expr::Abs(Box::new(Expr::Const(-3.0))).simplify(),
            Expr::Const(3.0)
        );

        // Nested absolute value collapses
        assert_eq!(
            *Expr::Abs(Box::new(Expr::Abs(var("x")))).simplify(),
            Expr::Abs(var("x"))
        );
    }

    #[test]
    fn test_insert() {
        // Replace x with 2*z in x + y.
        let expr = Box::new(Expr::Add(var("x"), var("y")));
        let replacement = Box::new(Expr::Mul(Box::new(Expr::Const(2.0)), var("z")));

        let result = expr.insert(|e| matches!(e, Expr::Var(v) if v.name == "x"), &replacement);

        assert_eq!(
            *result,
            Expr::Add(
                Box::new(Expr::Mul(Box::new(Expr::Const(2.0)), var("z"))),
                var("y"),
            )
        );
    }

    #[test]
    fn test_derivative() {
        // d/dx of a constant is 0.
        assert_eq!(*Expr::Const(5.0).derivative("x"), Expr::Const(0.0));

        // d/dx x = 1, d/dx y = 0
        assert_eq!(*var("x").derivative("x"), Expr::Const(1.0));
        assert_eq!(*var("y").derivative("x"), Expr::Const(0.0));

        // Sum rule
        let sum = Box::new(Expr::Add(var("x"), var("y")));
        assert_eq!(
            *sum.derivative("x"),
            Expr::Add(Box::new(Expr::Const(1.0)), Box::new(Expr::Const(0.0)))
        );

        // Product rule: (x*y)' = x*y' + y*x'
        let product = Box::new(Expr::Mul(var("x"), var("y")));
        assert_eq!(
            *product.derivative("x"),
            Expr::Add(
                Box::new(Expr::Mul(var("x"), Box::new(Expr::Const(0.0)))),
                Box::new(Expr::Mul(var("y"), Box::new(Expr::Const(1.0))))
            )
        );

        // Power rule: (x^3)' = 3*x^2*x'
        let power = Box::new(Expr::Pow(var("x"), 3));
        assert_eq!(
            *power.derivative("x"),
            Expr::Mul(
                Box::new(Expr::Mul(
                    Box::new(Expr::Const(3.0)),
                    Box::new(Expr::Pow(var("x"), 2))
                )),
                Box::new(Expr::Const(1.0))
            )
        );
    }

    #[test]
    fn test_complex_simplifications() {
        // (x + 0) * (y + 0) = x * y
        let expr = Box::new(Expr::Mul(
            Box::new(Expr::Add(var("x"), Box::new(Expr::Const(0.0)))),
            Box::new(Expr::Add(var("y"), Box::new(Expr::Const(0.0)))),
        ));
        assert_eq!(*expr.simplify(), Expr::Mul(var("x"), var("y")));

        // --x = x
        let expr = Box::new(Expr::Neg(Box::new(Expr::Neg(var("x")))));
        assert_eq!(*expr.simplify(), *var("x"));

        // (1 * x) * (y * 1) = x * y
        let expr = Box::new(Expr::Mul(
            Box::new(Expr::Mul(Box::new(Expr::Const(1.0)), var("x"))),
            Box::new(Expr::Mul(var("y"), Box::new(Expr::Const(1.0)))),
        ));
        assert_eq!(*expr.simplify(), Expr::Mul(var("x"), var("y")));

        // (x / y) / (x / y) = 1
        let div = Box::new(Expr::Div(var("x"), var("y")));
        let expr = Box::new(Expr::Div(div.clone(), div));
        assert_eq!(*expr.simplify(), Expr::Const(1.0));
    }

    #[test]
    fn test_special_functions() {
        // ||x|| = |x|
        let expr = Box::new(Expr::Abs(Box::new(Expr::Abs(var("x")))));
        assert_eq!(*expr.simplify(), Expr::Abs(var("x")));

        // sqrt(x^2) = |x|
        let expr = Box::new(Expr::Sqrt(Box::new(Expr::Pow(var("x"), 2))));
        assert_eq!(*expr.simplify(), Expr::Abs(var("x")));

        // exp(0) = 1, ln(1) = 0
        assert_eq!(
            *Expr::Exp(Box::new(Expr::Const(0.0))).simplify(),
            Expr::Const(1.0)
        );
        assert_eq!(
            *Expr::Ln(Box::new(Expr::Const(1.0))).simplify(),
            Expr::Const(0.0)
        );
    }

    #[test]
    fn test_display() {
        assert_eq!(format!("{}", Expr::Const(5.0)), "5");
        assert_eq!(format!("{}", *var("x")), "x");

        let sum = Expr::Add(var("x"), var("y"));
        assert_eq!(format!("{sum}"), "(x + y)");

        let product = Expr::Mul(var("x"), var("y"));
        assert_eq!(format!("{product}"), "(x * y)");

        let exp = Expr::Exp(var("x"));
        assert_eq!(format!("{exp}"), "exp(x)");

        let abs = Expr::Abs(var("x"));
        assert_eq!(format!("{abs}"), "|x|");

        let complex = Expr::Div(
            Box::new(Expr::Add(Box::new(Expr::Pow(var("x"), 2)), var("y"))),
            var("z"),
        );
        assert_eq!(format!("{complex}"), "(((x^2) + y) / z)");
    }

    #[test]
    fn test_cached_expressions() {
        // A cached node with a value is left untouched.
        let cached = Box::new(Expr::Cached(Box::new(Expr::Const(5.0)), Some(5.0)));
        assert_eq!(*cached.simplify(), *cached);

        // A cached node without a value is replaced by its simplified body.
        let uncached = Box::new(Expr::Cached(
            Box::new(Expr::Add(
                Box::new(Expr::Const(2.0)),
                Box::new(Expr::Const(3.0)),
            )),
            None,
        ));
        assert_eq!(*uncached.simplify(), Expr::Const(5.0));
    }

    #[test]
    fn test_pre_evaluate() {
        let mut cache = std::collections::HashMap::new();
        cache.insert("x".to_string(), 3.0);

        // Known variables are substituted; unknown ones stay symbolic.
        let expr = Expr::Add(var("x"), var("y"));
        assert_eq!(
            *expr.pre_evaluate(&mut cache),
            Expr::Add(Box::new(Expr::Const(3.0)), var("y"))
        );

        // Fully-known subtrees fold to a constant.
        let expr = Expr::Mul(var("x"), Box::new(Expr::Const(2.0)));
        assert_eq!(*expr.pre_evaluate(&mut cache), Expr::Const(6.0));
    }

    #[test]
    fn test_flatten() {
        // Constant expressions collapse to a single load.
        let constant = Expr::Add(Box::new(Expr::Const(2.0)), Box::new(Expr::Const(3.0))).flatten();
        assert_eq!(constant.constant_result, Some(5.0));
        assert_eq!(constant.max_var_index, None);
        assert!(matches!(constant.ops.as_slice(), [LinearOp::LoadConst(v)] if *v == 5.0));

        // Non-constant expressions become a post-order stack program.
        let flat = Expr::Mul(var("x"), Box::new(Expr::Const(2.0))).flatten();
        assert_eq!(flat.constant_result, None);
        assert_eq!(flat.max_var_index, Some(0));
        assert!(matches!(
            flat.ops.as_slice(),
            [LinearOp::LoadVar(0), LinearOp::LoadConst(c), LinearOp::Mul] if *c == 2.0
        ));
    }
}