symrs/expr/add.rs

1use std::collections::HashMap;
2
3use crate::symbol;
4use indexmap::IndexMap;
5use itertools::Itertools;
6
7use super::*;
8
9#[derive(Clone)]
10pub struct Add {
11    pub operands: Vec<Box<dyn Expr>>,
12}
13
14impl Add {
15    pub fn new_box(operands: Vec<&Box<dyn Expr>>) -> Box<dyn Expr> {
16        Box::new(Add {
17            operands: operands.iter().copied().cloned().collect(),
18        })
19    }
20
21    pub fn new_box_v2(operands: Vec<Box<dyn Expr>>) -> Box<dyn Expr> {
22        if operands.len() == 0 {
23            Integer::new_box(0)
24        } else if operands.len() == 1 {
25            operands[0].clone_box()
26        } else {
27            Box::new(Add { operands })
28        }
29    }
30
31    pub fn new<'a, Ops: IntoIterator<Item = &'a dyn Expr>>(operands: Ops) -> Self {
32        Add {
33            operands: operands.into_iter().map(|e| e.clone_box()).collect(),
34        }
35    }
36
37    pub fn new_v2(ops: Vec<Box<dyn Expr>>) -> Self {
38        Add { operands: ops }
39    }
40
41    pub fn term_coeffs(&self) -> IndexMap<Box<dyn Expr>, Rational> {
42        let mut term_coeffs: IndexMap<Box<dyn Expr>, Rational> = IndexMap::new();
43
44        for op in self.operands.iter() {
45            // let op = op.expand();
46            let (coeff, expr) = op.get_coeff();
47
48            let entry = term_coeffs.entry(expr).or_insert(Rational::zero());
49            *entry += coeff;
50        }
51        term_coeffs
52            .into_iter()
53            .filter(|(_, v)| !v.is_zero())
54            .collect()
55    }
56}
57
58impl From<Vec<&Box<dyn Expr>>> for Add {
59    fn from(value: Vec<&Box<dyn Expr>>) -> Self {
60        Add::new(value.iter().map(|x| &***x).collect::<Vec<_>>())
61    }
62}
63
64impl Expr for Add {
65    fn known_expr(&self) -> KnownExpr {
66        KnownExpr::Add(self)
67    }
68    fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
69        self.operands.iter().for_each(|e| f(&**e));
70    }
71
72    fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
73        let args: Vec<Box<dyn Expr>> = args.iter().cloned().collect();
74        Box::new(Add { operands: args })
75    }
76
77    fn clone_box(&self) -> Box<dyn Expr> {
78        Box::new(self.clone())
79    }
80
81    fn str(&self) -> String {
82        let pieces: Vec<_> = self
83            .operands
84            .iter()
85            .enumerate()
86            .map(|(i, op)| match KnownExpr::from_expr_box(op) {
87                KnownExpr::Mul(Mul { operands }) if operands.len() > 0 && i > 0 => {
88                    match KnownExpr::from_expr_box(&operands[0]) {
89                        KnownExpr::Integer(Integer { value: -1 }) => {
90                            let mul = op.str();
91                            format!(" - {}", mul[1..].to_string())
92                        }
93                        // KnownExpr::Rational(r) if *r < 0 => {
94                        //     let mul = op.str();
95                        //     let len = mul.len();
96                        //     format!(" - {}", mul[if len > 1 { 2} else {0}..if len > 1 { -1} else {}].to_string())
97                        //
98                        // }
99                        _ => format!(" + {}", op.str()),
100                    }
101                }
102
103                KnownExpr::Integer(integer) if integer.value < 0 => {
104                    format!(" - {}", op.str()[1..].to_string())
105                }
106
107                _ if i > 0 => format!(" + {}", op.str()),
108                _ => op.str(),
109            })
110            .collect();
111        format!("{}", pieces.join(""))
112    }
113
114    // Same as str(&self) but calls to_cpp() on each operand
115    fn to_cpp(&self) -> String {
116        let pieces: Vec<_> = self
117            .operands
118            .iter()
119            .enumerate()
120            .map(|(i, op)| match KnownExpr::from_expr_box(op) {
121                KnownExpr::Mul(Mul { operands }) if operands.len() > 0 && i > 0 => {
122                    match KnownExpr::from_expr_box(&operands[0]) {
123                        KnownExpr::Integer(Integer { value: -1 }) => {
124                            let mul = op.to_cpp();
125                            format!(" - {}", mul[1..].to_string())
126                        }
127                        _ => format!(" + {}", op.to_cpp()),
128                    }
129                }
130                _ if i > 0 => format!(" + {}", op.to_cpp()),
131                _ => op.to_cpp(),
132            })
133            .collect();
134        format!("{}", pieces.join(""))
135    }
136
137    fn simplify(&self) -> Box<dyn Expr> {
138        if self.operands.len() == 2 {
139            match (self.operands[0].known_expr(), self.operands[1].known_expr()) {
140                (KnownExpr::Integer(a), KnownExpr::Integer(b)) => {
141                    return Integer::new_box(a.value + b.value);
142                }
143                (KnownExpr::Rational(r1), KnownExpr::Rational(r2)) => return (r1 + r2).simplify(),
144                (KnownExpr::Integer(a), KnownExpr::Rational(r2)) => return a + r2,
145                (KnownExpr::Rational(r1), KnownExpr::Integer(b)) => return r1 + b,
146                _ => (),
147            }
148        }
149        return self.from_args(
150            self.args()
151                .iter()
152                .map(|a| a.map_expr(&|e| e.simplify()))
153                .collect(),
154        );
155        // For some reason the code below isn't equivalent
156        // self.from_args(self.args_map_exprs(&|expr| match expr.known_expr() {
157        //     _ => expr.simplify(),
158        // }))
159    }
160
161    fn simplify_with_dimension(&self, dim: usize) -> Box<dyn Expr> {
162        let expr = self;
163
164        match expr.known_expr() {
165            KnownExpr::Add(Add { operands }) => {
166                let operands = operands
167                    .iter()
168                    .map(|op| op.simplify_with_dimension(dim))
169                    .collect_vec();
170                // let add = Add::new_v2(operands);
171                // let term_coeffs = add.term_coeffs();
172                let mut snd_ord_spatial_derivatives: HashMap<&Box<dyn Expr>, Vec<usize>> =
173                    HashMap::new();
174                let mut res_ops: Vec<Box<dyn Expr>> = Vec::with_capacity(operands.len());
175
176                for op in &operands {
177                    match op.known_expr() {
178                        KnownExpr::Diff(Diff { f, vars }) => {
179                            let entry = snd_ord_spatial_derivatives.entry(f).or_insert(vec![0; 3]);
180
181                            if vars.len() == 1 {
182                                let (var, order) = vars.iter().next().unwrap();
183                                if *order != 2 {
184                                    res_ops.push(op.clone());
185                                    continue;
186                                }
187                                match var.name.as_str() {
188                                    "x" => entry[0] += 1,
189                                    "y" => entry[1] += 1,
190                                    "z" => entry[2] += 1,
191                                    _ => res_ops.push(op.clone()),
192                                }
193                            } else {
194                                res_ops.push(op.clone());
195                            }
196                        }
197
198                        _ => res_ops.push(op.clone()),
199                    }
200                }
201
202                let laplacian = symbol!("laplacian");
203
204                for (f, mut counts) in snd_ord_spatial_derivatives {
205                    let min = *counts[0..dim].iter().min().unwrap();
206
207                    if min == 1 {
208                        res_ops.push(laplacian * f.get_ref());
209                    } else if min >= 1 {
210                        res_ops.push(laplacian * min * f.get_ref());
211                    }
212                    for k in 0..dim {
213                        counts[k] -= min;
214                    }
215                    for k in 0..dim {
216                        let count = counts[k];
217                        if count > 0 {
218                            todo!()
219                        }
220                    }
221                }
222
223                Add::new_box_v2(res_ops).simplify()
224            }
225            _ => expr.simplify_with_dimension(dim),
226        }
227    }
228
229    fn expand(&self) -> Box<dyn Expr> {
230        let operands: Vec<Box<dyn Expr>> = self
231            .operands
232            .iter()
233            .flat_map(|op| {
234                let op = op.expand();
235
236                match KnownExpr::from_expr_box(&op) {
237                    KnownExpr::Add(Add { operands }) => operands.clone(),
238                    _ => vec![op.clone_box()],
239                }
240            })
241            .collect();
242
243        if operands.len() == 0 {
244            Integer::new_box(0)
245        } else if operands.len() == 1 {
246            operands[0].clone()
247        } else {
248            Box::new(Add { operands })
249        }
250    }
251
252    fn get_ref<'a>(&'a self) -> &'a dyn Expr {
253        self as &dyn Expr
254    }
255
256    fn terms<'a>(&'a self) -> Box<dyn Iterator<Item = &'a dyn Expr> + 'a> {
257        Box::new(self.operands.iter().map(|o| &**o))
258    }
259}
260
261impl std::fmt::Debug for Add {
262    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263        write!(f, "{}", self.srepr())
264    }
265}
266
267impl std::ops::Add for &dyn Expr {
268    type Output = Box<dyn Expr>;
269
270    fn add(self, rhs: Self) -> Self::Output {
271        if self.is_zero() {
272            return rhs.clone_box();
273        }
274        if rhs.is_zero() {
275            return self.clone_box();
276        }
277        if self == &*-rhs {
278            return Integer::new_box(0);
279        }
280
281        let mut term_coeffs: IndexMap<Box<dyn Expr>, Rational> = IndexMap::new();
282
283        match (KnownExpr::from_expr(self), KnownExpr::from_expr(rhs)) {
284            (
285                KnownExpr::Integer(Integer { value: a }),
286                KnownExpr::Integer(Integer { value: b }),
287            ) => return Integer::new_box(a + b),
288            (KnownExpr::Rational(r1), KnownExpr::Rational(r2)) => return Box::new(r1 + r2),
289            (KnownExpr::Add(Add { operands: ops_a }), KnownExpr::Add(Add { operands: ops_b })) => {
290                ops_a
291                    .iter()
292                    .chain(ops_b.iter())
293                    .filter(|x| !x.is_zero())
294                    .for_each(|op| {
295                        let (coeff, expr) = op.get_coeff();
296                        let entry = term_coeffs.entry(expr.clone()).or_insert(Rational::zero());
297                        *entry += coeff;
298                    });
299            }
300            (KnownExpr::Add(Add { operands }), _) => {
301                operands
302                    .iter()
303                    .map(|e| e.get_ref())
304                    .chain(iter::once(rhs))
305                    .filter(|x| !x.is_zero())
306                    .for_each(|op| {
307                        let (coeff, expr) = op.get_coeff();
308                        let entry = term_coeffs.entry(expr.clone()).or_insert(Rational::zero());
309                        *entry += coeff;
310                    });
311            }
312            (_, KnownExpr::Add(Add { operands })) => {
313                iter::once(self)
314                    .chain(operands.iter().map(|e| e.get_ref()))
315                    .filter(|x| !x.is_zero())
316                    .for_each(|op| {
317                        let (coeff, expr) = op.get_coeff();
318                        let entry = term_coeffs.entry(expr.clone()).or_insert(Rational::zero());
319                        *entry += coeff;
320                    });
321            }
322            _ => {
323                iter::once(self)
324                    .chain(iter::once(rhs))
325                    .filter(|x| !x.is_zero())
326                    .for_each(|op| {
327                        let (coeff, expr) = op.get_coeff();
328                        let entry = term_coeffs.entry(expr.clone()).or_insert(Rational::zero());
329                        *entry += coeff;
330                    });
331            }
332        };
333
334        let mut operands: Vec<Box<dyn Expr>> = Vec::with_capacity(term_coeffs.len());
335
336        for (expr, coeff) in term_coeffs {
337            if coeff.is_zero() {
338                continue;
339            }
340
341            if !coeff.is_one() {
342                operands.push(coeff.simplify() * expr);
343            } else {
344                operands.push(expr)
345            }
346        }
347
348        if operands.len() == 0 {
349            Integer::new_box(0)
350        } else if operands.len() == 1 {
351            operands[0].clone_box()
352        } else {
353            Box::new(Add { operands })
354        }
355    }
356}
357impl std::ops::Add for &Box<dyn Expr> {
358    type Output = Box<dyn Expr>;
359
360    fn add(self, rhs: &Box<dyn Expr>) -> Self::Output {
361        &**self + &**rhs
362    }
363}
364
365impl std::ops::Add<Box<dyn Expr>> for &dyn Expr {
366    type Output = Box<dyn Expr>;
367
368    fn add(self, rhs: Box<dyn Expr>) -> Self::Output {
369        self + &*rhs
370    }
371}
372
373impl std::ops::Add for Box<dyn Expr> {
374    type Output = Box<dyn Expr>;
375
376    fn add(self, rhs: Box<dyn Expr>) -> Self::Output {
377        &*self + &*rhs
378    }
379}
380
381impl std::ops::AddAssign for Box<dyn Expr> {
382    fn add_assign(&mut self, rhs: Self) {
383        *self = self.get_ref() + rhs.get_ref();
384    }
385}
386impl<'a> From<&'a Add> for &'a dyn Expr {
387    fn from(value: &'a Add) -> Self {
388        value as &'a dyn Expr
389    }
390}
391
392impl std::ops::Mul<&dyn Expr> for Add {
393    type Output = Mul;
394
395    fn mul(self, rhs: &dyn Expr) -> Self::Output {
396        Mul::new([&self as &dyn Expr, rhs])
397    }
398}
399
400impl std::ops::Sub for &dyn Expr {
401    type Output = Box<dyn Expr>;
402
403    fn sub(self, rhs: Self) -> Self::Output {
404        self + &*(-rhs)
405    }
406}
407impl std::ops::Sub for &Box<dyn Expr> {
408    type Output = Box<dyn Expr>;
409
410    fn sub(self, rhs: &Box<dyn Expr>) -> Self::Output {
411        &**self - &**rhs
412    }
413}
414
415impl std::ops::Sub<Box<dyn Expr>> for &Box<dyn Expr> {
416    type Output = Box<dyn Expr>;
417
418    fn sub(self, rhs: Box<dyn Expr>) -> Self::Output {
419        &**self - &*rhs
420    }
421}
422
423impl std::ops::Sub<&Box<dyn Expr>> for Box<dyn Expr> {
424    type Output = Box<dyn Expr>;
425
426    fn sub(self, rhs: &Box<dyn Expr>) -> Self::Output {
427        &*self - &**rhs
428    }
429}
430
431impl std::ops::Sub for Box<dyn Expr> {
432    type Output = Box<dyn Expr>;
433
434    fn sub(self, rhs: Box<dyn Expr>) -> Self::Output {
435        &*self - &*rhs
436    }
437}
438
439impl std::ops::SubAssign<&dyn Expr> for Box<dyn Expr> {
440    fn sub_assign(&mut self, rhs: &dyn Expr) {
441        *self = &**self - rhs;
442    }
443}
444
445impl std::ops::Add<&dyn Expr> for Box<dyn Expr> {
446    type Output = Box<dyn Expr>;
447
448    fn add(self, rhs: &dyn Expr) -> Self::Output {
449        &*self + rhs
450    }
451}
452
453impl std::ops::Add<isize> for Box<dyn Expr> {
454    type Output = Box<dyn Expr>;
455
456    fn add(self, rhs: isize) -> Self::Output {
457        &*self + Integer::new_box(rhs)
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an expression from its textual form, panicking if it is malformed.
    fn parse_expr(src: &str) -> Box<dyn Expr> {
        src.parse().unwrap()
    }

    #[test]
    fn test_srepr() {
        let a = Symbol::new_box("a");
        let b = Symbol::new_box("b");

        // Subtraction is represented as addition of a -1 multiple.
        assert_eq!(
            (a - b).srepr(),
            "Add(Symbol(a), Mul(Integer(-1), Symbol(b)))"
        );
    }

    #[test]
    fn test_srepr_2() {
        let a = Symbol::new_box("a");
        let b = Symbol::new_box("b");
        let c = Symbol::new_box("c");

        assert_eq!(
            (a - b * c).srepr(),
            "Add(Symbol(a), Mul(Integer(-1), Symbol(b), Symbol(c)))"
        );
    }

    #[test]
    fn test_simplify_dimension() {
        let second_derivatives = parse_expr("d2u/dx2 + d2u/dy2");

        assert_eq!(
            second_derivatives.simplify_with_dimension(2),
            parse_expr("laplacian * u")
        );
    }

    #[test]
    #[ignore]
    fn test_simplify_dim_advanced_add() {
        let source_term = parse_expr("c^2 * (∂^2u / ∂x^2 + ∂^2u / ∂y^2) + source");

        assert_eq!(
            source_term.simplify_with_dimension(2),
            parse_expr("c^2 * laplacian * u + source")
        );
    }
}