// symrs/expr/ops/subs.rs

1use itertools::Itertools;
2
3use crate::*;
4
5use super::factor_coeff_no_div;
6
7impl std::ops::Div<&Box<dyn Expr>> for &dyn Expr {
8    type Output = Box<dyn Expr>;
9
10    fn div(self, rhs: &Box<dyn Expr>) -> Self::Output {
11        self / rhs.get_ref()
12    }
13}
14
15pub fn subs<'a, E: Expr + ?Sized>(expr: &E, substitutions: &[[Box<dyn Expr>; 2]]) -> Box<dyn Expr> {
16    for [replaced, replacement] in substitutions {
17        if expr.get_ref() == replaced {
18            return replacement.clone_box();
19        }
20
21        match (replaced.known_expr(), expr.known_expr()) {
22            (KnownExpr::Mul(_), KnownExpr::Mul(_))
23                if factor_coeff_no_div(expr.get_ref(), replaced.get_ref()).is_some() =>
24            {
25                return expr.get_ref() / replaced * replacement;
26            }
27            (KnownExpr::Add(_), KnownExpr::Add(_)) => {
28                todo!("Implement subs for addition replacement")
29            }
30
31            _ => {}
32        }
33    }
34
35    expr.from_args(
36        expr.args()
37            .into_iter()
38            .map(|arg| {
39                if let Some(expr_vec) = arg.as_any().downcast_ref::<Vec<Box<dyn Expr>>>() {
40                    let res = expr_vec
41                        .into_iter()
42                        .map(|e| subs(e.get_ref(), substitutions))
43                        .collect_vec();
44                    res.clone_arg()
45                } else if let Some(expr) = arg.as_expr() {
46                    expr.subs(&substitutions).into()
47                } else {
48                    arg
49                }
50            })
51            .collect(),
52    )
53}
54pub fn subs_box(expr: &Box<dyn Expr>, substitutions: &[[Box<dyn Expr>; 2]]) -> Box<dyn Expr> {
55    subs(&**expr, substitutions)
56}
57
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare symbol matching a rule is swapped for that rule's replacement.
    #[test]
    fn test_subs() {
        let [x, y] = symbols!("x", "y");
        let rules = [[x.clone_box(), y.clone_box()]];

        let substituted = subs(x, &rules);
        assert_eq!(substituted, y);
    }

    /// A product rule (`x*y -> z`) also applies to a larger product containing it,
    /// keeping the leftover coefficient.
    #[test]
    fn test_subs_product() {
        let [x, y, z] = symbols!("x", "y", "z");
        let rules = [[x * y, z.clone_box()]];

        let scaled_product = x * y * 2;
        assert_eq!(subs_box(&scaled_product, &rules), z * 2);
    }

    /// Substitution recurses into the arguments of a function call.
    #[test]
    fn test_subs_symbol_in_function() {
        let x = symbol!("x");
        let doubled = x * 2;
        let sine = Func::new("sin", vec![doubled.get_ref()]);

        let rules = [[x.clone_box(), Symbol::new_box("point[0]")]];
        let substituted = subs(&sine, &rules);

        assert_eq!(substituted.to_cpp(), "std::sin(2 * point[0])");
    }
}