1use itertools::Itertools;
2
3use crate::*;
4
5use super::factor_coeff_no_div;
6
7impl std::ops::Div<&Box<dyn Expr>> for &dyn Expr {
8 type Output = Box<dyn Expr>;
9
10 fn div(self, rhs: &Box<dyn Expr>) -> Self::Output {
11 self / rhs.get_ref()
12 }
13}
14
15pub fn subs<'a, E: Expr + ?Sized>(expr: &E, substitutions: &[[Box<dyn Expr>; 2]]) -> Box<dyn Expr> {
16 for [replaced, replacement] in substitutions {
17 if expr.get_ref() == replaced {
18 return replacement.clone_box();
19 }
20
21 match (replaced.known_expr(), expr.known_expr()) {
22 (KnownExpr::Mul(_), KnownExpr::Mul(_))
23 if factor_coeff_no_div(expr.get_ref(), replaced.get_ref()).is_some() =>
24 {
25 return expr.get_ref() / replaced * replacement;
26 }
27 (KnownExpr::Add(_), KnownExpr::Add(_)) => {
28 todo!("Implement subs for addition replacement")
29 }
30
31 _ => {}
32 }
33 }
34
35 expr.from_args(
36 expr.args()
37 .into_iter()
38 .map(|arg| {
39 if let Some(expr_vec) = arg.as_any().downcast_ref::<Vec<Box<dyn Expr>>>() {
40 let res = expr_vec
41 .into_iter()
42 .map(|e| subs(e.get_ref(), substitutions))
43 .collect_vec();
44 res.clone_arg()
45 } else if let Some(expr) = arg.as_expr() {
46 expr.subs(&substitutions).into()
47 } else {
48 arg
49 }
50 })
51 .collect(),
52 )
53}
54pub fn subs_box(expr: &Box<dyn Expr>, substitutions: &[[Box<dyn Expr>; 2]]) -> Box<dyn Expr> {
55 subs(&**expr, substitutions)
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare symbol that equals the left-hand side is replaced outright.
    #[test]
    fn test_subs() {
        let [x, y] = symbols!("x", "y");
        let substitutions = [[x.clone_box(), y.clone_box()]];

        let substituted = subs(x, &substitutions);
        assert_eq!(substituted, y);
    }

    /// A product that is a scaled copy of the left-hand side keeps its scale
    /// factor after substitution: `2*x*y` with `x*y -> z` becomes `2*z`.
    #[test]
    fn test_subs_product() {
        let [x, y, z] = symbols!("x", "y", "z");
        let substitutions = [[x * y, z.clone_box()]];

        let scaled_product = x * y * 2;
        assert_eq!(subs_box(&scaled_product, &substitutions), z * 2);
    }

    /// Substitution recurses into the argument list of a function call.
    #[test]
    fn test_subs_symbol_in_function() {
        let x = symbol!("x");

        let doubled = x * 2;
        let expr = Func::new("sin", vec![doubled.get_ref()]);

        let substitutions = [[x.clone_box(), Symbol::new_box("point[0]")]];
        let substituted = subs(&expr, &substitutions);

        assert_eq!(substituted.to_cpp(), "std::sin(2 * point[0])");
    }
}