1use std::collections::HashMap;
22
/// A symbolic expression tree over `f64` constants and named variables.
///
/// There are no dedicated subtraction or division variants: `a - b` is built as
/// `Add(a, Neg(b))`, and [`diff`] expresses reciprocals as `Pow(x, Const(-1.0))`.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A numeric literal.
    Const(f64),
    /// A named variable, resolved from the bindings passed to [`eval`].
    Var(String),
    /// The sum of two sub-expressions.
    Add(Box<Expr>, Box<Expr>),
    /// The product of two sub-expressions.
    Mul(Box<Expr>, Box<Expr>),
    /// `base ^ exponent`, in that field order.
    Pow(Box<Expr>, Box<Expr>),
    /// Arithmetic negation of the inner expression.
    Neg(Box<Expr>),
    /// Sine; [`eval`] computes it with `f64::sin` (radians).
    Sin(Box<Expr>),
    /// Cosine; [`eval`] computes it with `f64::cos` (radians).
    Cos(Box<Expr>),
    /// The natural exponential `e^x`.
    Exp(Box<Expr>),
    /// The natural logarithm; [`eval`] rejects arguments `<= 0`.
    Ln(Box<Expr>),
}
53
54pub fn var(name: &str) -> Expr {
60 Expr::Var(name.to_string())
61}
62
63pub fn cst(v: f64) -> Expr {
65 Expr::Const(v)
66}
67
68impl Expr {
73 pub fn add_expr(self, rhs: Expr) -> Expr {
75 Expr::Add(Box::new(self), Box::new(rhs))
76 }
77
78 pub fn sub_expr(self, rhs: Expr) -> Expr {
80 Expr::Add(Box::new(self), Box::new(Expr::Neg(Box::new(rhs))))
81 }
82
83 pub fn mul_expr(self, rhs: Expr) -> Expr {
85 Expr::Mul(Box::new(self), Box::new(rhs))
86 }
87
88 pub fn pow(self, exp: Expr) -> Expr {
90 Expr::Pow(Box::new(self), Box::new(exp))
91 }
92
93 #[allow(clippy::should_implement_trait)]
95 pub fn neg(self) -> Expr {
96 Expr::Neg(Box::new(self))
97 }
98
99 pub fn sin(self) -> Expr {
101 Expr::Sin(Box::new(self))
102 }
103
104 pub fn cos(self) -> Expr {
106 Expr::Cos(Box::new(self))
107 }
108
109 pub fn exp(self) -> Expr {
111 Expr::Exp(Box::new(self))
112 }
113
114 pub fn ln(self) -> Expr {
116 Expr::Ln(Box::new(self))
117 }
118}
119
120pub fn eval(expr: &Expr, vars: &HashMap<String, f64>) -> Result<f64, String> {
129 match expr {
130 Expr::Const(c) => Ok(*c),
131 Expr::Var(name) => vars
132 .get(name)
133 .copied()
134 .ok_or_else(|| format!("undefined variable: {name}")),
135 Expr::Add(a, b) => Ok(eval(a, vars)? + eval(b, vars)?),
136 Expr::Mul(a, b) => Ok(eval(a, vars)? * eval(b, vars)?),
137 Expr::Pow(base, exp) => Ok(eval(base, vars)?.powf(eval(exp, vars)?)),
138 Expr::Neg(inner) => Ok(-eval(inner, vars)?),
139 Expr::Sin(inner) => Ok(eval(inner, vars)?.sin()),
140 Expr::Cos(inner) => Ok(eval(inner, vars)?.cos()),
141 Expr::Exp(inner) => Ok(eval(inner, vars)?.exp()),
142 Expr::Ln(inner) => {
143 let v = eval(inner, vars)?;
144 if v <= 0.0 {
145 Err(format!("ln of non-positive value: {v}"))
146 } else {
147 Ok(v.ln())
148 }
149 }
150 }
151}
152
153pub fn diff(expr: &Expr, var: &str) -> Expr {
161 match expr {
162 Expr::Const(_) => cst(0.0),
163 Expr::Var(name) => {
164 if name == var {
165 cst(1.0)
166 } else {
167 cst(0.0)
168 }
169 }
170 Expr::Add(f, g) => Expr::Add(Box::new(diff(f, var)), Box::new(diff(g, var))),
172 Expr::Mul(f, g) => Expr::Add(
174 Box::new(Expr::Mul(Box::new(diff(f, var)), g.clone())),
175 Box::new(Expr::Mul(f.clone(), Box::new(diff(g, var)))),
176 ),
177 Expr::Pow(base, exp) => {
179 if let Expr::Const(n) = exp.as_ref() {
182 let n = *n;
183 Expr::Mul(
185 Box::new(Expr::Mul(
186 Box::new(cst(n)),
187 Box::new(Expr::Pow(base.clone(), Box::new(cst(n - 1.0)))),
188 )),
189 Box::new(diff(base, var)),
190 )
191 } else {
192 let f = base.as_ref();
194 let g = exp.as_ref();
195 let fg = Expr::Pow(base.clone(), exp.clone());
196 let term1 = Expr::Mul(Box::new(diff(g, var)), Box::new(Expr::Ln(base.clone())));
197 let term2 = Expr::Mul(
198 g.clone().into(),
199 Box::new(Expr::Mul(
200 Box::new(diff(f, var)),
201 Box::new(Expr::Pow(base.clone(), Box::new(cst(-1.0)))),
202 )),
203 );
204 Expr::Mul(
205 Box::new(fg),
206 Box::new(Expr::Add(Box::new(term1), Box::new(term2))),
207 )
208 }
209 }
210 Expr::Neg(f) => Expr::Neg(Box::new(diff(f, var))),
212 Expr::Sin(f) => Expr::Mul(Box::new(Expr::Cos(f.clone())), Box::new(diff(f, var))),
214 Expr::Cos(f) => Expr::Neg(Box::new(Expr::Mul(
216 Box::new(Expr::Sin(f.clone())),
217 Box::new(diff(f, var)),
218 ))),
219 Expr::Exp(f) => Expr::Mul(Box::new(Expr::Exp(f.clone())), Box::new(diff(f, var))),
221 Expr::Ln(f) => Expr::Mul(
223 Box::new(diff(f, var)),
224 Box::new(Expr::Pow(f.clone(), Box::new(cst(-1.0)))),
225 ),
226 }
227}
228
229pub fn simplify(expr: &Expr) -> Expr {
243 match expr {
244 Expr::Const(_) | Expr::Var(_) => expr.clone(),
246
247 Expr::Add(a, b) => {
248 let a = simplify(a);
249 let b = simplify(b);
250 if let (Expr::Const(x), Expr::Const(y)) = (&a, &b) {
252 return cst(x + y);
253 }
254 if matches!(a, Expr::Const(x) if x == 0.0) {
256 return b;
257 }
258 if matches!(b, Expr::Const(x) if x == 0.0) {
260 return a;
261 }
262 Expr::Add(Box::new(a), Box::new(b))
263 }
264
265 Expr::Mul(a, b) => {
266 let a = simplify(a);
267 let b = simplify(b);
268 if let (Expr::Const(x), Expr::Const(y)) = (&a, &b) {
270 return cst(x * y);
271 }
272 if matches!(a, Expr::Const(x) if x == 0.0) {
274 return cst(0.0);
275 }
276 if matches!(b, Expr::Const(x) if x == 0.0) {
277 return cst(0.0);
278 }
279 if matches!(a, Expr::Const(x) if x == 1.0) {
281 return b;
282 }
283 if matches!(b, Expr::Const(x) if x == 1.0) {
285 return a;
286 }
287 Expr::Mul(Box::new(a), Box::new(b))
288 }
289
290 Expr::Pow(base, exp) => {
291 let base = simplify(base);
292 let exp = simplify(exp);
293 if let (Expr::Const(b), Expr::Const(e)) = (&base, &exp) {
295 return cst(b.powf(*e));
296 }
297 if matches!(exp, Expr::Const(e) if e == 0.0) {
299 return cst(1.0);
300 }
301 if matches!(exp, Expr::Const(e) if e == 1.0) {
303 return base;
304 }
305 Expr::Pow(Box::new(base), Box::new(exp))
306 }
307
308 Expr::Neg(inner) => {
309 let inner = simplify(inner);
310 if let Expr::Const(c) = &inner {
312 return cst(-c);
313 }
314 if let Expr::Neg(x) = inner {
316 return *x;
317 }
318 Expr::Neg(Box::new(inner))
319 }
320
321 Expr::Sin(inner) => Expr::Sin(Box::new(simplify(inner))),
322 Expr::Cos(inner) => Expr::Cos(Box::new(simplify(inner))),
323 Expr::Exp(inner) => {
324 let inner = simplify(inner);
325 if let Expr::Const(c) = &inner {
326 return cst(c.exp());
327 }
328 Expr::Exp(Box::new(inner))
329 }
330 Expr::Ln(inner) => {
331 let inner = simplify(inner);
332 if let Expr::Const(c) = &inner
333 && *c > 0.0
334 {
335 return cst(c.ln());
336 }
337 Expr::Ln(Box::new(inner))
338 }
339 }
340}
341
342pub fn to_string(expr: &Expr) -> String {
348 expr_to_str(expr)
349}
350
351fn expr_to_str(expr: &Expr) -> String {
352 match expr {
353 Expr::Const(c) => {
354 if c.fract() == 0.0 && c.abs() < 1e15 {
356 format!("{}", *c as i64)
357 } else {
358 format!("{c}")
359 }
360 }
361 Expr::Var(name) => name.clone(),
362 Expr::Add(a, b) => {
363 let bs = expr_to_str(b);
364 if let Expr::Neg(inner) = b.as_ref() {
366 format!("({} - {})", expr_to_str(a), expr_to_str(inner))
367 } else if let Some(bs_stripped) = bs.strip_prefix('-') {
368 format!("({} - {})", expr_to_str(a), bs_stripped)
369 } else {
370 format!("({} + {})", expr_to_str(a), bs)
371 }
372 }
373 Expr::Mul(a, b) => format!("({} * {})", expr_to_str(a), expr_to_str(b)),
374 Expr::Pow(base, exp) => format!("({}^{})", expr_to_str(base), expr_to_str(exp)),
375 Expr::Neg(inner) => format!("(-{})", expr_to_str(inner)),
376 Expr::Sin(inner) => format!("sin({})", expr_to_str(inner)),
377 Expr::Cos(inner) => format!("cos({})", expr_to_str(inner)),
378 Expr::Exp(inner) => format!("exp({})", expr_to_str(inner)),
379 Expr::Ln(inner) => format!("ln({})", expr_to_str(inner)),
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the variable bindings expected by [`eval`] from `(name, value)` pairs.
    fn vars(bindings: &[(&str, f64)]) -> HashMap<String, f64> {
        bindings.iter().map(|(k, v)| (k.to_string(), *v)).collect()
    }

    /// Asserts `got` is within `tol` of `want`, reporting both values on failure.
    fn assert_close(got: f64, want: f64, tol: f64) {
        assert!((got - want).abs() < tol, "got {got}, want {want}");
    }

    // ---- eval ----

    #[test]
    fn eval_const() {
        assert_eq!(eval(&cst(3.125), &HashMap::new()).unwrap(), 3.125);
    }

    #[test]
    fn eval_var_found() {
        let e = var("x");
        assert_eq!(eval(&e, &vars(&[("x", 5.0)])).unwrap(), 5.0);
    }

    #[test]
    fn eval_var_missing_returns_err() {
        let e = var("y");
        assert_eq!(
            eval(&e, &HashMap::new()),
            Err("undefined variable: y".to_string())
        );
    }

    #[test]
    fn eval_add() {
        let e = var("x").add_expr(cst(1.0));
        assert_eq!(eval(&e, &vars(&[("x", 4.0)])).unwrap(), 5.0);
    }

    #[test]
    fn eval_mul() {
        let e = var("x").mul_expr(cst(3.0));
        assert_eq!(eval(&e, &vars(&[("x", 2.0)])).unwrap(), 6.0);
    }

    #[test]
    fn eval_pow() {
        let e = var("x").pow(cst(3.0));
        assert_close(eval(&e, &vars(&[("x", 2.0)])).unwrap(), 8.0, 1e-12);
    }

    #[test]
    fn eval_neg() {
        let e = var("x").neg();
        assert_eq!(eval(&e, &vars(&[("x", 7.0)])).unwrap(), -7.0);
    }

    #[test]
    fn eval_sin() {
        let e = var("x").sin();
        assert_close(eval(&e, &vars(&[("x", 0.0)])).unwrap(), 0.0, 1e-12);
    }

    #[test]
    fn eval_cos() {
        let e = var("x").cos();
        assert_close(eval(&e, &vars(&[("x", 0.0)])).unwrap(), 1.0, 1e-12);
    }

    #[test]
    fn eval_exp() {
        let e = var("x").exp();
        assert_close(eval(&e, &vars(&[("x", 0.0)])).unwrap(), 1.0, 1e-12);
    }

    #[test]
    fn eval_ln() {
        let e = var("x").ln();
        assert_close(eval(&e, &vars(&[("x", 1.0)])).unwrap(), 0.0, 1e-12);
    }

    #[test]
    fn eval_ln_nonpositive_returns_err() {
        let e = var("x").ln();
        assert!(eval(&e, &vars(&[("x", 0.0)])).is_err());
        assert!(eval(&e, &vars(&[("x", -1.0)])).is_err());
    }

    #[test]
    fn eval_complex_poly() {
        // 3x^2 + 2x + 1 at x = 2 is 17.
        let x = var("x");
        let e = cst(3.0)
            .mul_expr(x.clone().pow(cst(2.0)))
            .add_expr(cst(2.0).mul_expr(x.clone()))
            .add_expr(cst(1.0));
        assert_close(eval(&e, &vars(&[("x", 2.0)])).unwrap(), 17.0, 1e-12);
    }

    // ---- diff ----

    #[test]
    fn diff_const_is_zero() {
        let e = diff(&cst(42.0), "x");
        assert_eq!(simplify(&e), cst(0.0));
    }

    #[test]
    fn diff_var_self_is_one() {
        let e = diff(&var("x"), "x");
        assert_eq!(simplify(&e), cst(1.0));
    }

    #[test]
    fn diff_var_other_is_zero() {
        let e = diff(&var("y"), "x");
        assert_eq!(simplify(&e), cst(0.0));
    }

    #[test]
    fn diff_linear() {
        // d/dx 3x = 3
        let e = cst(3.0).mul_expr(var("x"));
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &HashMap::new()).unwrap(), 3.0, 1e-12);
    }

    #[test]
    fn diff_quadratic() {
        // d/dx x^2 = 2x
        let e = var("x").pow(cst(2.0));
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &vars(&[("x", 3.0)])).unwrap(), 6.0, 1e-12);
    }

    #[test]
    fn diff_cubic() {
        // d/dx x^3 = 3x^2
        let e = var("x").pow(cst(3.0));
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &vars(&[("x", 2.0)])).unwrap(), 12.0, 1e-12);
    }

    #[test]
    fn diff_symbolic_exponent() {
        // d/dx x^y = y x^(y-1); at x = 2, y = 3 that is 12.
        let e = var("x").pow(var("y"));
        let d = simplify(&diff(&e, "x"));
        let got = eval(&d, &vars(&[("x", 2.0), ("y", 3.0)])).unwrap();
        assert_close(got, 12.0, 1e-10);
    }

    #[test]
    fn diff_general_power() {
        // d/dx x^x = x^x (ln x + 1)
        let e = var("x").pow(var("x"));
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &vars(&[("x", 1.0)])).unwrap(), 1.0, 1e-12);
        let want = 4.0 * (2.0_f64.ln() + 1.0);
        assert_close(eval(&d, &vars(&[("x", 2.0)])).unwrap(), want, 1e-10);
    }

    #[test]
    fn diff_sin() {
        // d/dx sin x = cos x
        let e = var("x").sin();
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &vars(&[("x", 0.0)])).unwrap(), 1.0, 1e-12);
    }

    #[test]
    fn diff_cos() {
        // d/dx cos x = -sin x
        let e = var("x").cos();
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &vars(&[("x", 0.0)])).unwrap(), 0.0, 1e-12);
    }

    #[test]
    fn diff_exp() {
        // d/dx e^x = e^x
        let e = var("x").exp();
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &vars(&[("x", 0.0)])).unwrap(), 1.0, 1e-12);
    }

    #[test]
    fn diff_ln() {
        // d/dx ln x = 1/x
        let e = var("x").ln();
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &vars(&[("x", 2.0)])).unwrap(), 0.5, 1e-12);
    }

    #[test]
    fn diff_product_rule() {
        // d/dx x sin x = sin x + x cos x, which is 0 at x = 0.
        let e = var("x").mul_expr(var("x").sin());
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &vars(&[("x", 0.0)])).unwrap(), 0.0, 1e-12);
    }

    #[test]
    fn diff_chain_sin_of_poly() {
        // d/dx sin(x^2) = 2x cos(x^2), which is 0 at x = 0.
        let e = var("x").pow(cst(2.0)).sin();
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &vars(&[("x", 0.0)])).unwrap(), 0.0, 1e-12);
    }

    #[test]
    fn diff_neg() {
        // d/dx (-x) = -1
        let e = var("x").neg();
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &HashMap::new()).unwrap(), -1.0, 1e-12);
    }

    // ---- simplify ----

    #[test]
    fn simplify_zero_plus_x() {
        let e = cst(0.0).add_expr(var("x"));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_x_plus_zero() {
        let e = var("x").add_expr(cst(0.0));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_zero_times_x() {
        let e = cst(0.0).mul_expr(var("x"));
        assert_eq!(simplify(&e), cst(0.0));
    }

    #[test]
    fn simplify_x_times_zero() {
        let e = var("x").mul_expr(cst(0.0));
        assert_eq!(simplify(&e), cst(0.0));
    }

    #[test]
    fn simplify_one_times_x() {
        let e = cst(1.0).mul_expr(var("x"));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_x_times_one() {
        let e = var("x").mul_expr(cst(1.0));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_x_pow_zero() {
        let e = var("x").pow(cst(0.0));
        assert_eq!(simplify(&e), cst(1.0));
    }

    #[test]
    fn simplify_x_pow_one() {
        let e = var("x").pow(cst(1.0));
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_double_neg() {
        let e = var("x").neg().neg();
        assert_eq!(simplify(&e), var("x"));
    }

    #[test]
    fn simplify_const_fold_add() {
        let e = cst(3.0).add_expr(cst(4.0));
        assert_eq!(simplify(&e), cst(7.0));
    }

    #[test]
    fn simplify_const_fold_mul() {
        let e = cst(3.0).mul_expr(cst(4.0));
        assert_eq!(simplify(&e), cst(12.0));
    }

    #[test]
    fn simplify_const_fold_pow() {
        let e = cst(2.0).pow(cst(10.0));
        assert_eq!(simplify(&e), cst(1024.0));
    }

    #[test]
    fn simplify_ln_one() {
        let e = cst(1.0).ln();
        assert_eq!(simplify(&e), cst(0.0));
    }

    // ---- to_string ----

    #[test]
    fn to_string_const() {
        assert_eq!(to_string(&cst(3.0)), "3");
    }

    #[test]
    fn to_string_fractional_const() {
        assert_eq!(to_string(&cst(2.5)), "2.5");
    }

    #[test]
    fn to_string_var() {
        assert_eq!(to_string(&var("theta")), "theta");
    }

    #[test]
    fn to_string_add() {
        let e = var("x").add_expr(cst(1.0));
        assert_eq!(to_string(&e), "(x + 1)");
    }

    #[test]
    fn to_string_sub() {
        let e = var("x").sub_expr(var("y"));
        assert_eq!(to_string(&e), "(x - y)");
    }

    #[test]
    fn to_string_negative_const_renders_as_subtraction() {
        let e = var("x").add_expr(cst(-3.0));
        assert_eq!(to_string(&e), "(x - 3)");
    }

    #[test]
    fn to_string_mul() {
        let e = var("a").mul_expr(var("b"));
        assert_eq!(to_string(&e), "(a * b)");
    }

    #[test]
    fn to_string_pow() {
        let e = var("x").pow(cst(2.0));
        assert_eq!(to_string(&e), "(x^2)");
    }

    #[test]
    fn to_string_sin() {
        assert_eq!(to_string(&var("x").sin()), "sin(x)");
    }

    #[test]
    fn to_string_cos() {
        assert_eq!(to_string(&var("x").cos()), "cos(x)");
    }

    #[test]
    fn to_string_exp() {
        assert_eq!(to_string(&var("x").exp()), "exp(x)");
    }

    #[test]
    fn to_string_ln() {
        assert_eq!(to_string(&var("x").ln()), "ln(x)");
    }

    #[test]
    fn to_string_neg() {
        assert_eq!(to_string(&var("x").neg()), "(-x)");
    }

    // ---- end-to-end ----

    #[test]
    fn diff_poly_numeric_check() {
        // d/dx (x^4 - 3x^2 + 2) = 4x^3 - 6x, which is -2 at x = 1.
        let x = var("x");
        let poly = x
            .clone()
            .pow(cst(4.0))
            .sub_expr(cst(3.0).mul_expr(x.clone().pow(cst(2.0))))
            .add_expr(cst(2.0));
        let d = simplify(&diff(&poly, "x"));
        assert_close(eval(&d, &vars(&[("x", 1.0)])).unwrap(), -2.0, 1e-10);
    }

    #[test]
    fn diff_exp_of_linear() {
        // d/dx e^(3x) = 3 e^(3x), which is 3 at x = 0.
        let e = cst(3.0).mul_expr(var("x")).exp();
        let d = simplify(&diff(&e, "x"));
        assert_close(eval(&d, &vars(&[("x", 0.0)])).unwrap(), 3.0, 1e-12);
    }
}