1use cranelift::prelude::*;
55use cranelift_codegen::ir::{immediates::Offset32, Value};
56use cranelift_module::Module;
57
58use crate::{errors::EquationError, operators};
59
60#[derive(Debug, Clone, PartialEq)]
67pub struct VarRef {
68 pub name: String,
69 pub vec_ref: Value,
70 pub index: u32,
71}
72
73#[derive(Debug, Clone, PartialEq)]
83pub enum Expr {
84 Const(f64),
86 Var(VarRef),
88 Add(Box<Expr>, Box<Expr>),
90 Mul(Box<Expr>, Box<Expr>),
92 Sub(Box<Expr>, Box<Expr>),
94 Div(Box<Expr>, Box<Expr>),
96 Abs(Box<Expr>),
98 Pow(Box<Expr>, i64),
100 PowFloat(Box<Expr>, f64),
102 PowExpr(Box<Expr>, Box<Expr>),
104 Exp(Box<Expr>),
106 Ln(Box<Expr>),
108 Sqrt(Box<Expr>),
110 Sin(Box<Expr>),
112 Cos(Box<Expr>),
114 Neg(Box<Expr>),
116 Cached(Box<Expr>, Option<f64>),
118}
119
/// One instruction of the postfix (stack-machine) form produced by `Expr::flatten`.
///
/// Loads push one value. Every other op pops its operands from the value stack and
/// pushes exactly one result.
#[derive(Clone, Debug)]
pub enum LinearOp {
    /// Push a constant.
    LoadConst(f64),
    /// Push the input variable stored at the given slot index.
    LoadVar(u32),
    /// Pop `rhs`, pop `lhs`, push `lhs + rhs`.
    Add,
    /// Pop `rhs`, pop `lhs`, push `lhs - rhs`.
    Sub,
    /// Pop `rhs`, pop `lhs`, push `lhs * rhs`.
    Mul,
    /// Pop `rhs`, pop `lhs`, push `lhs / rhs`.
    Div,
    /// Pop `x`, push `|x|`.
    Abs,
    /// Pop `x`, push `-x`.
    Neg,
    /// Pop `base`, push `base ^ n` for the fixed integer `n`.
    PowConst(i64),
    /// Pop `base`, push `base ^ p` for the fixed float `p`.
    PowFloat(f64),
    /// Pop `exponent`, then `base`, push `base ^ exponent`.
    PowExpr,
    /// Pop `x`, push `e^x`.
    Exp,
    /// Pop `x`, push `ln(x)`.
    Ln,
    /// Pop `x`, push `sqrt(x)`.
    Sqrt,
    /// Pop `x`, push `sin(x)`.
    Sin,
    /// Pop `x`, push `cos(x)`.
    Cos,
}
156
157#[derive(Debug, Clone)]
159pub struct FlattenedExpr {
160 pub ops: Vec<LinearOp>,
162 pub max_var_index: Option<u32>,
164 pub constant_result: Option<f64>,
166}
167
168impl Expr {
169 pub fn pre_evaluate(
171 &self,
172 var_cache: &mut std::collections::HashMap<String, f64>,
173 ) -> Box<Expr> {
174 match self {
175 Expr::Const(_) => Box::new(self.clone()),
176
177 Expr::Var(var_ref) => {
178 if let Some(&value) = var_cache.get(&var_ref.name) {
180 Box::new(Expr::Const(value))
181 } else {
182 Box::new(self.clone())
183 }
184 }
185
186 Expr::Add(left, right) => {
187 let l = left.pre_evaluate(var_cache);
188 let r = right.pre_evaluate(var_cache);
189 match (&*l, &*r) {
190 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a + b)),
191 _ => Box::new(Expr::Add(l, r)),
192 }
193 }
194
195 Expr::Sub(left, right) => {
196 let l = left.pre_evaluate(var_cache);
197 let r = right.pre_evaluate(var_cache);
198 match (&*l, &*r) {
199 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a - b)),
200 _ => Box::new(Expr::Sub(l, r)),
201 }
202 }
203
204 Expr::Mul(left, right) => {
205 let l = left.pre_evaluate(var_cache);
206 let r = right.pre_evaluate(var_cache);
207 match (&*l, &*r) {
208 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a * b)),
209 _ => Box::new(Expr::Mul(l, r)),
210 }
211 }
212
213 Expr::Div(left, right) => {
214 let l = left.pre_evaluate(var_cache);
215 let r = right.pre_evaluate(var_cache);
216 match (&*l, &*r) {
217 (Expr::Const(a), Expr::Const(b)) if *b != 0.0 => Box::new(Expr::Const(a / b)),
218 _ => Box::new(Expr::Div(l, r)),
219 }
220 }
221
222 Expr::Abs(expr) => {
223 let e = expr.pre_evaluate(var_cache);
224 match &*e {
225 Expr::Const(a) => Box::new(Expr::Const(a.abs())),
226 _ => Box::new(Expr::Abs(e)),
227 }
228 }
229
230 Expr::Neg(expr) => {
231 let e = expr.pre_evaluate(var_cache);
232 match &*e {
233 Expr::Const(a) => Box::new(Expr::Const(-a)),
234 _ => Box::new(Expr::Neg(e)),
235 }
236 }
237
238 Expr::Pow(base, exp) => {
239 let b = base.pre_evaluate(var_cache);
240 match &*b {
241 Expr::Const(a) => Box::new(Expr::Const(a.powi(*exp as i32))),
242 _ => Box::new(Expr::Pow(b, *exp)),
243 }
244 }
245
246 Expr::PowFloat(base, exp) => {
247 let b = base.pre_evaluate(var_cache);
248 match &*b {
249 Expr::Const(a) => Box::new(Expr::Const(a.powf(*exp))),
250 _ => Box::new(Expr::PowFloat(b, *exp)),
251 }
252 }
253
254 Expr::PowExpr(base, exponent) => {
255 let b = base.pre_evaluate(var_cache);
256 let e = exponent.pre_evaluate(var_cache);
257 match (&*b, &*e) {
258 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a.powf(*b))),
259 _ => Box::new(Expr::PowExpr(b, e)),
260 }
261 }
262
263 Expr::Exp(expr) => {
264 let e = expr.pre_evaluate(var_cache);
265 match &*e {
266 Expr::Const(a) => Box::new(Expr::Const(a.exp())),
267 _ => Box::new(Expr::Exp(e)),
268 }
269 }
270
271 Expr::Ln(expr) => {
272 let e = expr.pre_evaluate(var_cache);
273 match &*e {
274 Expr::Const(a) if *a > 0.0 => Box::new(Expr::Const(a.ln())),
275 _ => Box::new(Expr::Ln(e)),
276 }
277 }
278
279 Expr::Sqrt(expr) => {
280 let e = expr.pre_evaluate(var_cache);
281 match &*e {
282 Expr::Const(a) if *a >= 0.0 => Box::new(Expr::Const(a.sqrt())),
283 _ => Box::new(Expr::Sqrt(e)),
284 }
285 }
286
287 Expr::Sin(expr) => {
288 let e = expr.pre_evaluate(var_cache);
289 match &*e {
290 Expr::Const(a) => Box::new(Expr::Const(a.sin())),
291 _ => Box::new(Expr::Sin(e)),
292 }
293 }
294
295 Expr::Cos(expr) => {
296 let e = expr.pre_evaluate(var_cache);
297 match &*e {
298 Expr::Const(a) => Box::new(Expr::Const(a.cos())),
299 _ => Box::new(Expr::Cos(e)),
300 }
301 }
302
303 Expr::Cached(expr, _) => expr.pre_evaluate(var_cache),
304 }
305 }
306
307 pub fn derivative(&self, with_respect_to: &str) -> Box<Expr> {
330 match self {
331 Expr::Const(_) => Box::new(Expr::Const(0.0)),
332
333 Expr::Var(var_ref) => {
334 if var_ref.name == with_respect_to {
335 Box::new(Expr::Const(1.0))
336 } else {
337 Box::new(Expr::Const(0.0))
338 }
339 }
340
341 Expr::Add(left, right) => {
342 Box::new(Expr::Add(
344 left.derivative(with_respect_to),
345 right.derivative(with_respect_to),
346 ))
347 }
348
349 Expr::Sub(left, right) => {
350 Box::new(Expr::Sub(
352 left.derivative(with_respect_to),
353 right.derivative(with_respect_to),
354 ))
355 }
356
357 Expr::Mul(left, right) => {
358 Box::new(Expr::Add(
360 Box::new(Expr::Mul(left.clone(), right.derivative(with_respect_to))),
361 Box::new(Expr::Mul(right.clone(), left.derivative(with_respect_to))),
362 ))
363 }
364
365 Expr::Div(left, right) => {
366 Box::new(Expr::Div(
368 Box::new(Expr::Sub(
369 Box::new(Expr::Mul(right.clone(), left.derivative(with_respect_to))),
370 Box::new(Expr::Mul(left.clone(), right.derivative(with_respect_to))),
371 )),
372 Box::new(Expr::Pow(right.clone(), 2)),
373 ))
374 }
375
376 Expr::Abs(expr) => {
377 Box::new(Expr::Mul(
379 Box::new(Expr::Div(expr.clone(), Box::new(Expr::Abs(expr.clone())))),
380 expr.derivative(with_respect_to),
381 ))
382 }
383
384 Expr::Pow(base, exp) => {
385 Box::new(Expr::Mul(
387 Box::new(Expr::Mul(
388 Box::new(Expr::Const(*exp as f64)),
389 Box::new(Expr::Pow(base.clone(), exp - 1)),
390 )),
391 base.derivative(with_respect_to),
392 ))
393 }
394
395 Expr::PowFloat(base, exp) => {
396 Box::new(Expr::Mul(
398 Box::new(Expr::Mul(
399 Box::new(Expr::Const(*exp)),
400 Box::new(Expr::PowFloat(base.clone(), exp - 1.0)),
401 )),
402 base.derivative(with_respect_to),
403 ))
404 }
405
406 Expr::PowExpr(base, exponent) => {
407 Box::new(Expr::Mul(
410 Box::new(Expr::PowExpr(base.clone(), exponent.clone())),
411 Box::new(Expr::Add(
412 Box::new(Expr::Mul(
413 exponent.derivative(with_respect_to),
414 Box::new(Expr::Ln(base.clone())),
415 )),
416 Box::new(Expr::Mul(
417 exponent.clone(),
418 Box::new(Expr::Div(base.derivative(with_respect_to), base.clone())),
419 )),
420 )),
421 ))
422 }
423
424 Expr::Exp(expr) => {
425 Box::new(Expr::Mul(
427 Box::new(Expr::Exp(expr.clone())),
428 expr.derivative(with_respect_to),
429 ))
430 }
431
432 Expr::Ln(expr) => {
433 Box::new(Expr::Mul(
435 Box::new(Expr::Div(Box::new(Expr::Const(1.0)), expr.clone())),
436 expr.derivative(with_respect_to),
437 ))
438 }
439
440 Expr::Sqrt(expr) => {
441 Box::new(Expr::Mul(
443 Box::new(Expr::Div(
444 Box::new(Expr::Const(1.0)),
445 Box::new(Expr::Mul(
446 Box::new(Expr::Const(2.0)),
447 Box::new(Expr::Sqrt(expr.clone())),
448 )),
449 )),
450 expr.derivative(with_respect_to),
451 ))
452 }
453
454 Expr::Sin(expr) => {
455 Box::new(Expr::Mul(
457 Box::new(Expr::Cos(expr.clone())),
458 expr.derivative(with_respect_to),
459 ))
460 }
461
462 Expr::Cos(expr) => {
463 Box::new(Expr::Mul(
465 Box::new(Expr::Neg(Box::new(Expr::Sin(expr.clone())))),
466 expr.derivative(with_respect_to),
467 ))
468 }
469
470 Expr::Neg(expr) => {
471 Box::new(Expr::Neg(expr.derivative(with_respect_to)))
473 }
474
475 Expr::Cached(expr, _) => expr.derivative(with_respect_to),
476 }
477 }
478
479 pub fn simplify(&self) -> Box<Expr> {
510 match self {
511 Expr::Const(_) | Expr::Var(_) => Box::new(self.clone()),
513
514 Expr::Add(left, right) => {
515 let l = left.simplify();
516 let r = right.simplify();
517 match (&*l, &*r) {
518 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a + b)),
520 (expr, Expr::Const(0.0)) | (Expr::Const(0.0), expr) => Box::new(expr.clone()),
522 (Expr::Mul(a1, x1), Expr::Mul(a2, x2)) if x1 == x2 => {
524 let combined_coeff = Expr::Add(a1.clone(), a2.clone()).simplify();
525 Box::new(Expr::Mul(combined_coeff, x1.clone()))
526 }
527 (Expr::Add(x, c1), c2)
529 if matches!(**c1, Expr::Const(_)) && matches!(*c2, Expr::Const(_)) =>
530 {
531 Box::new(Expr::Add(
532 x.clone(),
533 Expr::Add(c1.clone(), Box::new(c2.clone())).simplify(),
534 ))
535 }
536 _ => Box::new(Expr::Add(l, r)),
537 }
538 }
539
540 Expr::Sub(left, right) => {
541 let l = left.simplify();
542 let r = right.simplify();
543 match (&*l, &*r) {
544 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a - b)),
546 (expr, Expr::Const(0.0)) => Box::new(expr.clone()),
548 (a, b) if a == b => Box::new(Expr::Const(0.0)),
550 (Expr::Mul(a1, x1), Expr::Mul(a2, x2)) if x1 == x2 => {
552 let combined_coeff = Expr::Sub(a1.clone(), a2.clone()).simplify();
553 Box::new(Expr::Mul(combined_coeff, x1.clone()))
554 }
555 (x, Expr::Const(c)) => {
557 Box::new(Expr::Add(Box::new(x.clone()), Box::new(Expr::Const(-c))))
558 }
559 _ => Box::new(Expr::Sub(l, r)),
560 }
561 }
562
563 Expr::Mul(left, right) => {
564 let l = left.simplify();
565 let r = right.simplify();
566
567 if l == r {
569 return Box::new(Expr::Pow(l, 2)); }
571
572 match (&*l, &*r) {
573 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a * b)),
575 (Expr::Const(0.0), _) | (_, Expr::Const(0.0)) => Box::new(Expr::Const(0.0)),
577 (expr, Expr::Const(1.0)) | (Expr::Const(1.0), expr) => Box::new(expr.clone()),
579 (expr, Expr::Const(-1.0)) | (Expr::Const(-1.0), expr) => {
581 Box::new(Expr::Neg(Box::new(expr.clone())))
582 }
583 (Expr::Pow(b1, e1), Expr::Pow(b2, e2)) if b1 == b2 => {
585 Box::new(Expr::Pow(b1.clone(), e1 + e2))
586 }
587 (Expr::Const(c), Expr::Add(x, y)) | (Expr::Add(x, y), Expr::Const(c))
589 if c.abs() < 10.0 =>
590 {
591 Expr::Add(
592 Box::new(Expr::Mul(Box::new(Expr::Const(*c)), x.clone())),
593 Box::new(Expr::Mul(Box::new(Expr::Const(*c)), y.clone())),
594 )
595 .simplify()
596 }
597 (expr, Expr::Const(2.0)) | (Expr::Const(2.0), expr) => {
599 Box::new(Expr::Add(Box::new(expr.clone()), Box::new(expr.clone())))
600 }
601 (Expr::Mul(c1, x), c2)
603 if matches!(**c1, Expr::Const(_)) && matches!(*c2, Expr::Const(_)) =>
604 {
605 Box::new(Expr::Mul(
606 Expr::Mul(c1.clone(), Box::new(c2.clone())).simplify(),
607 x.clone(),
608 ))
609 }
610 _ => Box::new(Expr::Mul(l, r)),
611 }
612 }
613
614 Expr::Div(left, right) => {
615 let l = left.simplify();
616 let r = right.simplify();
617 match (&*l, &*r) {
618 (Expr::Const(a), Expr::Const(b)) if *b != 0.0 => Box::new(Expr::Const(a / b)),
620 (Expr::Const(0.0), _) => Box::new(Expr::Const(0.0)),
622 (expr, Expr::Const(1.0)) => Box::new(expr.clone()),
624 (expr, Expr::Const(-1.0)) => Box::new(Expr::Neg(Box::new(expr.clone()))),
626 (a, b) if a == b => Box::new(Expr::Const(1.0)),
628 (Expr::Pow(b1, e1), Expr::Pow(b2, e2)) if b1 == b2 => {
630 Box::new(Expr::Pow(b1.clone(), e1 - e2))
631 }
632 (x, Expr::Const(c)) if *c != 0.0 && c.abs() > 1e-10 => Box::new(Expr::Mul(
634 Box::new(x.clone()),
635 Box::new(Expr::Const(1.0 / c)),
636 )),
637 (Expr::Div(x, y), z) => Box::new(Expr::Div(
639 x.clone(),
640 Box::new(Expr::Mul(y.clone(), Box::new(z.clone()))),
641 )),
642 _ => Box::new(Expr::Div(l, r)),
643 }
644 }
645
646 Expr::Abs(expr) => {
647 let e = expr.simplify();
648 match &*e {
649 Expr::Const(a) => Box::new(Expr::Const(a.abs())),
651 Expr::Abs(inner) => Box::new(Expr::Abs(inner.clone())),
653 Expr::Neg(inner) => Box::new(Expr::Abs(inner.clone())),
655 Expr::Pow(_, exp) if exp % 2 == 0 => e,
657 _ => Box::new(Expr::Abs(e)),
658 }
659 }
660
661 Expr::Pow(base, exp) => {
662 let b = base.simplify();
663 match (&*b, exp) {
664 (_, 0) => Box::new(Expr::Const(1.0)),
666 (Expr::Const(a), exp) => Box::new(Expr::Const(a.powi(*exp as i32))),
668 (expr, 1) => Box::new(expr.clone()),
670 (expr, exp) if *exp < 0 => Box::new(Expr::Div(
672 Box::new(Expr::Const(1.0)),
673 Box::new(Expr::Pow(Box::new(expr.clone()), -exp)),
674 )),
675 (Expr::Pow(inner_base, inner_exp), outer_exp) => {
677 Box::new(Expr::Pow(inner_base.clone(), inner_exp * outer_exp))
678 }
679 (Expr::Mul(x, y), n) if *n >= 2 && *n <= 4 => Box::new(Expr::Mul(
681 Box::new(Expr::Pow(x.clone(), *n)),
682 Box::new(Expr::Pow(y.clone(), *n)),
683 )),
684 _ => Box::new(Expr::Pow(b, *exp)),
685 }
686 }
687
688 Expr::PowFloat(base, exp) => {
689 let b = base.simplify();
690 match (&*b, exp) {
691 (_, exp) if exp.abs() < 1e-10 => Box::new(Expr::Const(1.0)),
693 (Expr::Const(a), exp) => Box::new(Expr::Const(a.powf(*exp))),
695 (expr, exp) if (exp - 1.0).abs() < 1e-10 => Box::new(expr.clone()),
697 (expr, exp) if exp.fract().abs() < 1e-10 => {
699 Box::new(Expr::Pow(Box::new(expr.clone()), *exp as i64))
700 }
701 _ => Box::new(Expr::PowFloat(b, *exp)),
702 }
703 }
704
705 Expr::PowExpr(base, exponent) => {
706 let b = base.simplify();
707 let e = exponent.simplify();
708 match (&*b, &*e) {
709 (Expr::Const(a), Expr::Const(b)) => Box::new(Expr::Const(a.powf(*b))),
711 (_, Expr::Const(0.0)) => Box::new(Expr::Const(1.0)),
713 (expr, Expr::Const(1.0)) => Box::new(expr.clone()),
715 (expr, Expr::Const(exp)) if exp.fract().abs() < 1e-10 => {
717 Box::new(Expr::Pow(Box::new(expr.clone()), *exp as i64))
718 }
719 (expr, Expr::Const(exp)) => {
720 Box::new(Expr::PowFloat(Box::new(expr.clone()), *exp))
721 }
722 _ => Box::new(Expr::PowExpr(b, e)),
723 }
724 }
725
726 Expr::Exp(expr) => {
727 let e = expr.simplify();
728 match &*e {
729 Expr::Const(0.0) => Box::new(Expr::Const(1.0)),
731 Expr::Const(a) => Box::new(Expr::Const(a.exp())),
733 Expr::Ln(inner) => inner.clone(),
735 Expr::Add(x, y) => Box::new(Expr::Mul(
737 Box::new(Expr::Exp(x.clone())),
738 Box::new(Expr::Exp(y.clone())),
739 )),
740 _ => Box::new(Expr::Exp(e)),
741 }
742 }
743
744 Expr::Ln(expr) => {
745 let e = expr.simplify();
746 match &*e {
747 Expr::Const(a) if *a > 0.0 => Box::new(Expr::Const(a.ln())),
749 Expr::Const(1.0) => Box::new(Expr::Const(0.0)),
751 Expr::Exp(inner) => inner.clone(),
753 Expr::Mul(x, y) => Box::new(Expr::Add(
755 Box::new(Expr::Ln(x.clone())),
756 Box::new(Expr::Ln(y.clone())),
757 )),
758 Expr::Div(x, y) => Box::new(Expr::Sub(
760 Box::new(Expr::Ln(x.clone())),
761 Box::new(Expr::Ln(y.clone())),
762 )),
763 Expr::Pow(x, n) => Box::new(Expr::Mul(
765 Box::new(Expr::Const(*n as f64)),
766 Box::new(Expr::Ln(x.clone())),
767 )),
768 _ => Box::new(Expr::Ln(e)),
769 }
770 }
771
772 Expr::Sqrt(expr) => {
773 let e = expr.simplify();
774 match &*e {
775 Expr::Const(a) if *a >= 0.0 => Box::new(Expr::Const(a.sqrt())),
777 Expr::Const(0.0) => Box::new(Expr::Const(0.0)),
779 Expr::Const(1.0) => Box::new(Expr::Const(1.0)),
781 Expr::Pow(x, 2) => Box::new(Expr::Abs(x.clone())),
783 Expr::Mul(x, y) => Box::new(Expr::Mul(
785 Box::new(Expr::Sqrt(x.clone())),
786 Box::new(Expr::Sqrt(y.clone())),
787 )),
788 _ => Box::new(Expr::Sqrt(e)),
789 }
790 }
791
792 Expr::Sin(expr) => {
793 let e = expr.simplify();
794 match &*e {
795 Expr::Const(0.0) => Box::new(Expr::Const(0.0)),
797 Expr::Const(a) => Box::new(Expr::Const(a.sin())),
799 _ => Box::new(Expr::Sin(e)),
800 }
801 }
802
803 Expr::Cos(expr) => {
804 let e = expr.simplify();
805 match &*e {
806 Expr::Const(0.0) => Box::new(Expr::Const(1.0)),
808 Expr::Const(a) => Box::new(Expr::Const(a.cos())),
810 _ => Box::new(Expr::Cos(e)),
811 }
812 }
813
814 Expr::Neg(expr) => {
815 let e = expr.simplify();
816 match &*e {
817 Expr::Const(a) => Box::new(Expr::Const(-a)),
819 Expr::Neg(inner) => inner.clone(),
821 Expr::Add(x, y) => {
823 Expr::Sub(Box::new(Expr::Neg(x.clone())), y.clone()).simplify()
824 }
825 Expr::Sub(x, y) => {
827 Expr::Add(Box::new(Expr::Neg(x.clone())), y.clone()).simplify()
828 }
829 Expr::Mul(c, x) if matches!(**c, Expr::Const(_)) => {
831 Expr::Mul(Box::new(Expr::Neg(c.clone())), x.clone()).simplify()
832 }
833 _ => Box::new(Expr::Neg(e)),
834 }
835 }
836
837 Expr::Cached(expr, cached_value) => {
838 if cached_value.is_some() {
839 Box::new(self.clone())
840 } else {
841 expr.simplify()
843 }
844 }
845 }
846 }
847
848 pub fn insert<F>(&self, predicate: F, replacement: &Expr) -> Box<Expr>
861 where
862 F: Fn(&Expr) -> bool + Clone,
863 {
864 if predicate(self) {
865 Box::new(replacement.clone())
866 } else {
867 match self {
868 Expr::Const(_) | Expr::Var(_) => Box::new(self.clone()),
869 Expr::Add(left, right) => Box::new(Expr::Add(
870 left.insert(predicate.clone(), replacement),
871 right.insert(predicate, replacement),
872 )),
873 Expr::Mul(left, right) => Box::new(Expr::Mul(
874 left.insert(predicate.clone(), replacement),
875 right.insert(predicate, replacement),
876 )),
877 Expr::Sub(left, right) => Box::new(Expr::Sub(
878 left.insert(predicate.clone(), replacement),
879 right.insert(predicate, replacement),
880 )),
881 Expr::Div(left, right) => Box::new(Expr::Div(
882 left.insert(predicate.clone(), replacement),
883 right.insert(predicate, replacement),
884 )),
885 Expr::Abs(expr) => Box::new(Expr::Abs(expr.insert(predicate, replacement))),
886 Expr::Pow(base, exp) => {
887 Box::new(Expr::Pow(base.insert(predicate, replacement), *exp))
888 }
889 Expr::PowFloat(base, exp) => {
890 Box::new(Expr::PowFloat(base.insert(predicate, replacement), *exp))
891 }
892 Expr::PowExpr(base, exponent) => Box::new(Expr::PowExpr(
893 base.insert(predicate.clone(), replacement),
894 exponent.insert(predicate, replacement),
895 )),
896 Expr::Exp(expr) => Box::new(Expr::Exp(expr.insert(predicate, replacement))),
897 Expr::Ln(expr) => Box::new(Expr::Ln(expr.insert(predicate, replacement))),
898 Expr::Sqrt(expr) => Box::new(Expr::Sqrt(expr.insert(predicate, replacement))),
899 Expr::Sin(expr) => Box::new(Expr::Sin(expr.insert(predicate, replacement))),
900 Expr::Cos(expr) => Box::new(Expr::Cos(expr.insert(predicate, replacement))),
901 Expr::Neg(expr) => Box::new(Expr::Neg(expr.insert(predicate, replacement))),
902 Expr::Cached(expr, _) => {
903 Box::new(Expr::Cached(expr.insert(predicate, replacement), None))
904 }
905 }
906 }
907 }
908
909 pub fn flatten(&self) -> FlattenedExpr {
920 let mut ops = Vec::new();
921 let mut max_var_index = None;
922
923 if let Some(constant) = self.try_evaluate_constant() {
925 return FlattenedExpr {
926 ops: vec![LinearOp::LoadConst(constant)],
927 max_var_index: None,
928 constant_result: Some(constant),
929 };
930 }
931
932 self.flatten_recursive(&mut ops, &mut max_var_index);
933
934 FlattenedExpr {
935 ops,
936 max_var_index,
937 constant_result: None,
938 }
939 }
940
941 fn try_evaluate_constant(&self) -> Option<f64> {
943 match self {
944 Expr::Const(val) => Some(*val),
945 Expr::Var(_) => None,
946 Expr::Add(left, right) => {
947 Some(left.try_evaluate_constant()? + right.try_evaluate_constant()?)
948 }
949 Expr::Sub(left, right) => {
950 Some(left.try_evaluate_constant()? - right.try_evaluate_constant()?)
951 }
952 Expr::Mul(left, right) => {
953 Some(left.try_evaluate_constant()? * right.try_evaluate_constant()?)
954 }
955 Expr::Div(left, right) => {
956 let r = right.try_evaluate_constant()?;
957 if r.abs() < 1e-10 {
958 return None;
959 }
960 Some(left.try_evaluate_constant()? / r)
961 }
962 Expr::Abs(expr) => Some(expr.try_evaluate_constant()?.abs()),
963 Expr::Neg(expr) => Some(-expr.try_evaluate_constant()?),
964 Expr::Pow(base, exp) => Some(base.try_evaluate_constant()?.powi(*exp as i32)),
965 Expr::PowFloat(base, exp) => Some(base.try_evaluate_constant()?.powf(*exp)),
966 Expr::PowExpr(base, exponent) => Some(
967 base.try_evaluate_constant()?
968 .powf(exponent.try_evaluate_constant()?),
969 ),
970 Expr::Exp(expr) => Some(expr.try_evaluate_constant()?.exp()),
971 Expr::Ln(expr) => {
972 let val = expr.try_evaluate_constant()?;
973 if val <= 0.0 {
974 return None;
975 }
976 Some(val.ln())
977 }
978 Expr::Sqrt(expr) => {
979 let val = expr.try_evaluate_constant()?;
980 if val < 0.0 {
981 return None;
982 }
983 Some(val.sqrt())
984 }
985 Expr::Sin(expr) => Some(expr.try_evaluate_constant()?.sin()),
986 Expr::Cos(expr) => Some(expr.try_evaluate_constant()?.cos()),
987 Expr::Cached(expr, cached_value) => {
988 cached_value.or_else(|| expr.try_evaluate_constant())
989 }
990 }
991 }
992
993 fn flatten_recursive(&self, ops: &mut Vec<LinearOp>, max_var_index: &mut Option<u32>) {
995 match self {
996 Expr::Const(val) => {
997 ops.push(LinearOp::LoadConst(*val));
998 }
999
1000 Expr::Var(var_ref) => {
1001 let index = var_ref.index;
1002 *max_var_index = Some(max_var_index.unwrap_or(0).max(index));
1003 ops.push(LinearOp::LoadVar(index));
1004 }
1005
1006 Expr::Add(left, right) => {
1007 left.flatten_recursive(ops, max_var_index);
1008 right.flatten_recursive(ops, max_var_index);
1009 ops.push(LinearOp::Add);
1010 }
1011
1012 Expr::Sub(left, right) => {
1013 left.flatten_recursive(ops, max_var_index);
1014 right.flatten_recursive(ops, max_var_index);
1015 ops.push(LinearOp::Sub);
1016 }
1017
1018 Expr::Mul(left, right) => {
1019 left.flatten_recursive(ops, max_var_index);
1020 right.flatten_recursive(ops, max_var_index);
1021 ops.push(LinearOp::Mul);
1022 }
1023
1024 Expr::Div(left, right) => {
1025 left.flatten_recursive(ops, max_var_index);
1026 right.flatten_recursive(ops, max_var_index);
1027 ops.push(LinearOp::Div);
1028 }
1029
1030 Expr::Abs(expr) => {
1031 expr.flatten_recursive(ops, max_var_index);
1032 ops.push(LinearOp::Abs);
1033 }
1034
1035 Expr::Neg(expr) => {
1036 expr.flatten_recursive(ops, max_var_index);
1037 ops.push(LinearOp::Neg);
1038 }
1039
1040 Expr::Pow(base, exp) => {
1041 base.flatten_recursive(ops, max_var_index);
1042 ops.push(LinearOp::PowConst(*exp));
1043 }
1044
1045 Expr::PowFloat(base, exp) => {
1046 base.flatten_recursive(ops, max_var_index);
1047 ops.push(LinearOp::PowFloat(*exp));
1048 }
1049
1050 Expr::PowExpr(base, exponent) => {
1051 base.flatten_recursive(ops, max_var_index);
1052 exponent.flatten_recursive(ops, max_var_index);
1053 ops.push(LinearOp::PowExpr);
1054 }
1055
1056 Expr::Exp(expr) => {
1057 expr.flatten_recursive(ops, max_var_index);
1058 ops.push(LinearOp::Exp);
1059 }
1060
1061 Expr::Ln(expr) => {
1062 expr.flatten_recursive(ops, max_var_index);
1063 ops.push(LinearOp::Ln);
1064 }
1065
1066 Expr::Sqrt(expr) => {
1067 expr.flatten_recursive(ops, max_var_index);
1068 ops.push(LinearOp::Sqrt);
1069 }
1070
1071 Expr::Sin(expr) => {
1072 expr.flatten_recursive(ops, max_var_index);
1073 ops.push(LinearOp::Sin);
1074 }
1075
1076 Expr::Cos(expr) => {
1077 expr.flatten_recursive(ops, max_var_index);
1078 ops.push(LinearOp::Cos);
1079 }
1080
1081 Expr::Cached(expr, cached_value) => {
1082 if let Some(val) = cached_value {
1083 ops.push(LinearOp::LoadConst(*val));
1084 } else {
1085 expr.flatten_recursive(ops, max_var_index);
1086 }
1087 }
1088 }
1089 }
1090
1091 pub fn codegen_flattened(
1096 &self,
1097 builder: &mut FunctionBuilder,
1098 module: &mut dyn Module,
1099 ) -> Result<Value, EquationError> {
1100 let flattened = self.flatten();
1101
1102 if let Some(constant) = flattened.constant_result {
1104 return Ok(builder.ins().f64const(constant));
1105 }
1106
1107 let mut value_stack = Vec::with_capacity(flattened.ops.len());
1109
1110 let input_ptr = builder
1112 .func
1113 .dfg
1114 .block_params(builder.current_block().unwrap())[0];
1115
1116 for op in &flattened.ops {
1118 match op {
1119 LinearOp::LoadConst(val) => {
1120 value_stack.push(builder.ins().f64const(*val));
1121 }
1122
1123 LinearOp::LoadVar(index) => {
1124 let offset = (*index as i32) * 8;
1125 let memflags = MemFlags::new().with_aligned().with_readonly().with_notrap();
1126 let val =
1127 builder
1128 .ins()
1129 .load(types::F64, memflags, input_ptr, Offset32::new(offset));
1130 value_stack.push(val);
1131 }
1132
1133 LinearOp::Add => {
1134 let rhs = value_stack.pop().unwrap();
1135 let lhs = value_stack.pop().unwrap();
1136 value_stack.push(builder.ins().fadd(lhs, rhs));
1137 }
1138
1139 LinearOp::Sub => {
1140 let rhs = value_stack.pop().unwrap();
1141 let lhs = value_stack.pop().unwrap();
1142 value_stack.push(builder.ins().fsub(lhs, rhs));
1143 }
1144
1145 LinearOp::Mul => {
1146 let rhs = value_stack.pop().unwrap();
1147 let lhs = value_stack.pop().unwrap();
1148 value_stack.push(builder.ins().fmul(lhs, rhs));
1149 }
1150
1151 LinearOp::Div => {
1152 let rhs = value_stack.pop().unwrap();
1153 let lhs = value_stack.pop().unwrap();
1154 value_stack.push(builder.ins().fdiv(lhs, rhs));
1155 }
1156
1157 LinearOp::Abs => {
1158 let val = value_stack.pop().unwrap();
1159 value_stack.push(builder.ins().fabs(val));
1160 }
1161
1162 LinearOp::Neg => {
1163 let val = value_stack.pop().unwrap();
1164 value_stack.push(builder.ins().fneg(val));
1165 }
1166
1167 LinearOp::PowConst(exp) => {
1168 let base = value_stack.pop().unwrap();
1169 let result = match *exp {
1170 0 => builder.ins().f64const(1.0),
1171 1 => base,
1172 2 => builder.ins().fmul(base, base),
1173 3 => {
1174 let square = builder.ins().fmul(base, base);
1175 builder.ins().fmul(square, base)
1176 }
1177 4 => {
1178 let square = builder.ins().fmul(base, base);
1179 builder.ins().fmul(square, square)
1180 }
1181 -1 => {
1182 let one = builder.ins().f64const(1.0);
1183 builder.ins().fdiv(one, base)
1184 }
1185 -2 => {
1186 let square = builder.ins().fmul(base, base);
1187 let one = builder.ins().f64const(1.0);
1188 builder.ins().fdiv(one, square)
1189 }
1190 _ => {
1191 generate_optimized_power(builder, base, *exp)
1193 }
1194 };
1195 value_stack.push(result);
1196 }
1197
1198 LinearOp::PowFloat(exp) => {
1199 let base = value_stack.pop().unwrap();
1200 let func_id = crate::operators::pow::link_powf(module).unwrap();
1201 let exp_val = builder.ins().f64const(*exp);
1202 let result =
1203 crate::operators::pow::call_powf(builder, module, func_id, base, exp_val);
1204 value_stack.push(result);
1205 }
1206
1207 LinearOp::PowExpr => {
1208 let exponent = value_stack.pop().unwrap();
1209 let base = value_stack.pop().unwrap();
1210 let func_id = crate::operators::pow::link_powf(module).unwrap();
1211 let result =
1212 crate::operators::pow::call_powf(builder, module, func_id, base, exponent);
1213 value_stack.push(result);
1214 }
1215
1216 LinearOp::Exp => {
1217 let arg = value_stack.pop().unwrap();
1218 let func_id = operators::exp::link_exp(module).unwrap();
1219 let result = operators::exp::call_exp(builder, module, func_id, arg);
1220 value_stack.push(result);
1221 }
1222
1223 LinearOp::Ln => {
1224 let arg = value_stack.pop().unwrap();
1225 let func_id = operators::ln::link_ln(module).unwrap();
1226 let result = operators::ln::call_ln(builder, module, func_id, arg);
1227 value_stack.push(result);
1228 }
1229
1230 LinearOp::Sqrt => {
1231 let arg = value_stack.pop().unwrap();
1232 let func_id = operators::sqrt::link_sqrt(module).unwrap();
1233 let result = operators::sqrt::call_sqrt(builder, module, func_id, arg);
1234 value_stack.push(result);
1235 }
1236
1237 LinearOp::Sin => {
1238 let arg = value_stack.pop().unwrap();
1239 let func_id = crate::operators::trigonometric::link_sin(module).unwrap();
1240 let result =
1241 crate::operators::trigonometric::call_sin(builder, module, func_id, arg);
1242 value_stack.push(result);
1243 }
1244
1245 LinearOp::Cos => {
1246 let arg = value_stack.pop().unwrap();
1247 let func_id = crate::operators::trigonometric::link_cos(module).unwrap();
1248 let result =
1249 crate::operators::trigonometric::call_cos(builder, module, func_id, arg);
1250 value_stack.push(result);
1251 }
1252 }
1253 }
1254
1255 Ok(value_stack.pop().unwrap())
1257 }
1258}
1259
1260fn generate_optimized_power(builder: &mut FunctionBuilder, base: Value, exp: i64) -> Value {
1262 if exp == 0 {
1263 return builder.ins().f64const(1.0);
1264 }
1265
1266 if exp == 1 {
1267 return base;
1268 }
1269
1270 let abs_exp = exp.abs();
1271 let mut result = builder.ins().f64const(1.0);
1272 let mut current_base = base;
1273 let mut remaining = abs_exp;
1274
1275 while remaining > 0 {
1277 if remaining & 1 == 1 {
1278 result = builder.ins().fmul(result, current_base);
1279 }
1280 if remaining > 1 {
1281 current_base = builder.ins().fmul(current_base, current_base);
1282 }
1283 remaining >>= 1;
1284 }
1285
1286 if exp < 0 {
1287 let one = builder.ins().f64const(1.0);
1288 builder.ins().fdiv(one, result)
1289 } else {
1290 result
1291 }
1292}
1293
1294impl std::fmt::Display for Expr {
1306 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1307 match self {
1308 Expr::Const(val) => write!(f, "{val}"),
1309 Expr::Var(var_ref) => write!(f, "{0}", var_ref.name),
1310 Expr::Add(left, right) => write!(f, "({left} + {right})"),
1311 Expr::Mul(left, right) => write!(f, "({left} * {right})"),
1312 Expr::Sub(left, right) => write!(f, "({left} - {right})"),
1313 Expr::Div(left, right) => write!(f, "({left} / {right})"),
1314 Expr::Abs(expr) => write!(f, "|{expr}|"),
1315 Expr::Pow(base, exp) => write!(f, "({base}^{exp})"),
1316 Expr::PowFloat(base, exp) => write!(f, "({base}^{exp})"),
1317 Expr::PowExpr(base, exponent) => write!(f, "({base}^{exponent})"),
1318 Expr::Exp(expr) => write!(f, "exp({expr})"),
1319 Expr::Ln(expr) => write!(f, "ln({expr})"),
1320 Expr::Sqrt(expr) => write!(f, "sqrt({expr})"),
1321 Expr::Sin(expr) => write!(f, "sin({expr})"),
1322 Expr::Cos(expr) => write!(f, "cos({expr})"),
1323 Expr::Neg(expr) => write!(f, "-({expr})"),
1324 Expr::Cached(expr, _) => write!(f, "{expr}"),
1325 }
1326 }
1327}
1328
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable expression named `name`. All test variables share slot 0 and a
    /// placeholder Cranelift value.
    fn var(name: &str) -> Box<Expr> {
        Box::new(Expr::Var(VarRef {
            name: name.to_string(),
            vec_ref: Value::from_u32(0),
            index: 0,
        }))
    }

    #[test]
    fn test_simplify() {
        // Constant folding: 2 + 3 = 5
        assert_eq!(
            *Expr::Add(Box::new(Expr::Const(2.0)), Box::new(Expr::Const(3.0))).simplify(),
            Expr::Const(5.0)
        );

        // Additive identity: x + 0 = x
        assert_eq!(
            *Expr::Add(var("x"), Box::new(Expr::Const(0.0))).simplify(),
            *var("x")
        );

        // Multiplicative identity: x * 1 = x
        assert_eq!(
            *Expr::Mul(var("x"), Box::new(Expr::Const(1.0))).simplify(),
            *var("x")
        );

        // Multiplication by zero: x * 0 = 0
        assert_eq!(
            *Expr::Mul(var("x"), Box::new(Expr::Const(0.0))).simplify(),
            Expr::Const(0.0)
        );

        // Division by one: x / 1 = x
        assert_eq!(
            *Expr::Div(var("x"), Box::new(Expr::Const(1.0))).simplify(),
            *var("x")
        );

        // Self-division: x / x = 1
        assert_eq!(*Expr::Div(var("x"), var("x")).simplify(), Expr::Const(1.0));

        // Trivial powers: x^0 = 1, x^1 = x
        assert_eq!(*Expr::Pow(var("x"), 0).simplify(), Expr::Const(1.0));
        assert_eq!(*Expr::Pow(var("x"), 1).simplify(), *var("x"));

        // Absolute value of a constant
        assert_eq!(
            *Expr::Abs(Box::new(Expr::Const(-3.0))).simplify(),
            Expr::Const(3.0)
        );

        // Nested absolute value collapses
        assert_eq!(
            *Expr::Abs(Box::new(Expr::Abs(var("x")))).simplify(),
            Expr::Abs(var("x"))
        );
    }

    #[test]
    fn test_insert() {
        // x + y, with x replaced by 2 * z
        let expr = Box::new(Expr::Add(var("x"), var("y")));
        let replacement = Box::new(Expr::Mul(Box::new(Expr::Const(2.0)), var("z")));

        let result = expr.insert(|e| matches!(e, Expr::Var(v) if v.name == "x"), &replacement);

        assert_eq!(
            *result,
            Expr::Add(
                Box::new(Expr::Mul(Box::new(Expr::Const(2.0)), var("z"),)),
                var("y"),
            )
        );
    }

    #[test]
    fn test_derivative() {
        // Constants differentiate to zero.
        assert_eq!(*Expr::Const(5.0).derivative("x"), Expr::Const(0.0));

        // The variable itself differentiates to one; other variables to zero.
        assert_eq!(*var("x").derivative("x"), Expr::Const(1.0));
        assert_eq!(*var("y").derivative("x"), Expr::Const(0.0));

        // Sum rule (unsimplified).
        let sum = Box::new(Expr::Add(var("x"), var("y")));
        assert_eq!(
            *sum.derivative("x"),
            Expr::Add(Box::new(Expr::Const(1.0)), Box::new(Expr::Const(0.0)))
        );

        // Product rule (unsimplified).
        let product = Box::new(Expr::Mul(var("x"), var("y")));
        assert_eq!(
            *product.derivative("x"),
            Expr::Add(
                Box::new(Expr::Mul(var("x"), Box::new(Expr::Const(0.0)))),
                Box::new(Expr::Mul(var("y"), Box::new(Expr::Const(1.0))))
            )
        );

        // Power rule (unsimplified).
        let power = Box::new(Expr::Pow(var("x"), 3));
        assert_eq!(
            *power.derivative("x"),
            Expr::Mul(
                Box::new(Expr::Mul(
                    Box::new(Expr::Const(3.0)),
                    Box::new(Expr::Pow(var("x"), 2))
                )),
                Box::new(Expr::Const(1.0))
            )
        );
    }

    #[test]
    fn test_complex_simplifications() {
        // (x + 0) * (y + 0) = x * y
        let expr = Box::new(Expr::Mul(
            Box::new(Expr::Add(var("x"), Box::new(Expr::Const(0.0)))),
            Box::new(Expr::Add(var("y"), Box::new(Expr::Const(0.0)))),
        ));
        assert_eq!(*expr.simplify(), Expr::Mul(var("x"), var("y")));

        // --x = x
        let expr = Box::new(Expr::Neg(Box::new(Expr::Neg(var("x")))));
        assert_eq!(*expr.simplify(), *var("x"));

        // (1 * x) * (y * 1) = x * y
        let expr = Box::new(Expr::Mul(
            Box::new(Expr::Mul(Box::new(Expr::Const(1.0)), var("x"))),
            Box::new(Expr::Mul(var("y"), Box::new(Expr::Const(1.0)))),
        ));
        assert_eq!(*expr.simplify(), Expr::Mul(var("x"), var("y")));

        // (x / y) / (x / y) = 1
        let div = Box::new(Expr::Div(var("x"), var("y")));
        let expr = Box::new(Expr::Div(div.clone(), div));
        assert_eq!(*expr.simplify(), Expr::Const(1.0));
    }

    #[test]
    fn test_special_functions() {
        // ||x|| = |x|
        let expr = Box::new(Expr::Abs(Box::new(Expr::Abs(var("x")))));
        assert_eq!(*expr.simplify(), Expr::Abs(var("x")));

        // sqrt(x^2) = |x|
        let expr = Box::new(Expr::Sqrt(Box::new(Expr::Pow(var("x"), 2))));
        assert_eq!(*expr.simplify(), Expr::Abs(var("x")));

        // exp(0) = 1 and ln(1) = 0
        assert_eq!(
            *Expr::Exp(Box::new(Expr::Const(0.0))).simplify(),
            Expr::Const(1.0)
        );
        assert_eq!(
            *Expr::Ln(Box::new(Expr::Const(1.0))).simplify(),
            Expr::Const(0.0)
        );
    }

    #[test]
    fn test_display() {
        assert_eq!(format!("{}", Expr::Const(5.0)), "5");
        assert_eq!(format!("{}", *var("x")), "x");

        let sum = Expr::Add(var("x"), var("y"));
        assert_eq!(format!("{sum}"), "(x + y)");

        let product = Expr::Mul(var("x"), var("y"));
        assert_eq!(format!("{product}"), "(x * y)");

        let exp = Expr::Exp(var("x"));
        assert_eq!(format!("{exp}"), "exp(x)");

        let abs = Expr::Abs(var("x"));
        assert_eq!(format!("{abs}"), "|x|");

        let complex = Expr::Div(
            Box::new(Expr::Add(Box::new(Expr::Pow(var("x"), 2)), var("y"))),
            var("z"),
        );
        assert_eq!(format!("{complex}"), "(((x^2) + y) / z)");
    }

    #[test]
    fn test_cached_expressions() {
        // A cached node with a stored value is left untouched.
        let cached = Box::new(Expr::Cached(Box::new(Expr::Const(5.0)), Some(5.0)));
        assert_eq!(*cached.simplify(), *cached);

        // An uncached node is unwrapped and its contents simplified.
        let uncached = Box::new(Expr::Cached(
            Box::new(Expr::Add(
                Box::new(Expr::Const(2.0)),
                Box::new(Expr::Const(3.0)),
            )),
            None,
        ));
        assert_eq!(*uncached.simplify(), Expr::Const(5.0));
    }
}