symrs/expr/diff.rs

1use indexmap::IndexMap;
2use itertools::Itertools;
3
4use super::*;
5#[derive(Clone)]
6pub struct Diff {
7    pub f: Box<dyn Expr>,
8    pub vars: IndexMap<Symbol, usize>,
9}
10
11pub trait IntoVarOrder {
12    fn into_var_order(self) -> (Symbol, usize);
13}
14
15impl Diff {
16    pub fn new<'a, It: IntoIterator<Item = &'a Box<dyn Expr>>>(
17        f: &Box<dyn Expr>,
18        vars: It,
19    ) -> Box<dyn Expr> {
20        Box::new(Diff::new_move(
21            f.clone(),
22            vars.into_iter().map(|v| v.as_symbol().unwrap()),
23        ))
24    }
25
26    pub fn new_v2(f: Box<dyn Expr>, vars: IndexMap<Symbol, usize>) -> Diff {
27        Diff { f, vars }
28    }
29
30    pub fn idiff(f: Box<dyn Expr>, var: Symbol, order: usize) -> Self {
31        let mut vars = IndexMap::new();
32        vars.insert(var, order);
33        Self { f, vars }
34    }
35
36    pub fn new_move<I, T>(f: Box<dyn Expr>, vars: I) -> Diff
37    where
38        I: IntoIterator<Item = T>,
39        T: IntoVarOrder,
40    {
41        let mut vars_orders = IndexMap::new();
42        for var in vars {
43            let (var, order) = var.into_var_order();
44            let entry = vars_orders.entry(var).or_insert(0);
45            *entry += order;
46        }
47        Diff {
48            f,
49            vars: vars_orders,
50        }
51    }
52}
53
54impl IntoVarOrder for (Symbol, usize) {
55    fn into_var_order(self) -> (Symbol, usize) {
56        self
57    }
58}
59
60impl IntoVarOrder for Symbol {
61    fn into_var_order(self) -> (Symbol, usize) {
62        (self, 1)
63    }
64}
65
66impl IntoVarOrder for &str {
67    fn into_var_order(self) -> (Symbol, usize) {
68        (Symbol::new(self), 1)
69    }
70}
71impl IntoVarOrder for char {
72    fn into_var_order(self) -> (Symbol, usize) {
73        (Symbol::new(&self.to_string()), 1)
74    }
75}
76
77impl<T: ToString> IntoVarOrder for &(T, usize) {
78    fn into_var_order(self) -> (Symbol, usize) {
79        (Symbol::new(&self.0.to_string()), self.1)
80    }
81}
82
83impl Expr for Diff {
84    fn get_ref<'a>(&'a self) -> &'a dyn Expr {
85        self as &dyn Expr
86    }
87    fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
88        f(&*self.f);
89        f(&self
90            .vars
91            .iter()
92            .map(|(var, order)| (var.clone(), *order))
93            .collect::<Vec<(Symbol, usize)>>());
94    }
95
96    fn known_expr(&self) -> KnownExpr {
97        KnownExpr::Diff(self)
98    }
99
100    fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
101        let vars = &*args[1];
102        let vars = vars as &dyn Any;
103        let vars = vars.downcast_ref::<Vec<(Symbol, usize)>>().unwrap();
104        Box::new(Diff::new_v2(
105            args[0].clone().into(),
106            IndexMap::from_iter(vars.clone()),
107        ))
108    }
109
110    fn clone_box(&self) -> Box<dyn Expr> {
111        Box::new(self.clone())
112    }
113
114    fn str(&self) -> String {
115        let order = self.vars.values().sum::<usize>();
116        let exponent = if order > 1 {
117            format!("^{}", order)
118        } else {
119            "".to_string()
120        };
121        let mut f = self.f.str();
122        if f.len() > 1 {
123            f = format!("({})", f);
124        }
125
126        let denom = self
127            .vars
128            .iter()
129            .map(|(var, order)| {
130                if *order == 1 {
131                    format!("∂{var}")
132                } else {
133                    format!("∂{}^{}", var.str(), order)
134                }
135            })
136            .join(".");
137
138        format!("∂{}{f} / {denom}", exponent,)
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::symbols;

    /// Differentiates `body` with respect to the symbol `t`, `order` times.
    fn d_dt(body: Box<dyn Expr>, order: usize) -> Diff {
        Diff::new_move(body, vec![Symbol::new("t"); order])
    }

    #[test]
    fn test_str() {
        let [u] = symbols!("u");
        let expr = d_dt(u.ipow(2), 2);
        assert_eq!(expr.str(), "∂^2(u^2) / ∂t^2");
    }

    #[test]
    fn test_str_first_order() {
        let [u] = symbols!("u");
        let expr = d_dt(u.ipow(2), 1);
        assert_eq!(expr.str(), "∂(u^2) / ∂t");
    }

    #[test]
    fn test_str_symbol_first_order() {
        let [u] = symbols!("u");
        let expr = d_dt(u.clone_box(), 1);
        assert_eq!(expr.str(), "∂u / ∂t");
    }

    #[test]
    fn test_args() {
        let [u] = symbols!("u");
        let expr = d_dt(u.ipow(2), 2);
        // Round-trip through the generic argument interface must be lossless.
        let rebuilt = expr.from_args(expr.args());
        assert_eq!(rebuilt.srepr(), expr.srepr());
    }
}