symrs/expr/
mul.rs

1use super::*;
2use indexmap::IndexMap;
3
4#[derive(Clone)]
5pub struct Mul {
6    pub operands: Vec<Box<dyn Expr>>,
7}
8
9impl Expr for Mul {
10    fn known_expr(&self) -> KnownExpr {
11        KnownExpr::Mul(self)
12    }
13    fn get_ref<'a>(&'a self) -> &'a dyn Expr {
14        self as &dyn Expr
15    }
16
17    fn as_mul(&self) -> Option<&Mul> {
18        Some(self)
19    }
20    fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
21        self.operands.iter().for_each(|e| f(&**e));
22    }
23
24    fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
25        let args: Vec<Box<dyn Expr>> = args.iter().cloned().collect();
26        Box::new(Mul { operands: args })
27    }
28
29    fn clone_box(&self) -> Box<dyn Expr> {
30        Box::new(self.clone())
31    }
32
33    fn str(&self) -> String {
34        let pieces: Vec<_> = self
35            .operands
36            .iter()
37            .enumerate()
38            .map(|(i, op)| match KnownExpr::from_expr_box(op) {
39                KnownExpr::Integer(Integer { value: -1 }) if i == 0 => "-".to_string(),
40                KnownExpr::Add(_) if self.operands.len() > 1 => format!("({})", op.str()),
41                KnownExpr::Pow(pow) if self.operands.len() > 1 => format!("({})", pow.str()),
42                KnownExpr::Rational(r) if self.operands.len() > 1 => format!("({})", r.str()),
43                KnownExpr::Symbol(Symbol { name })
44                    if i < self.operands.len() - 1 && name.len() > 1 =>
45                {
46                    format!("{name}.")
47                }
48                KnownExpr::Integer(Integer { value }) => value.to_string(),
49                KnownExpr::Symbol(Symbol { name }) => name.to_string(),
50                _ if self.operands.len() > 1 => format!("({})", op.str()),
51                _ => op.str(),
52            })
53            .collect();
54        format!("{}", pieces.join(""))
55    }
56
57    fn is_number(&self) -> bool {
58        self.operands.iter().all(|op| op.is_number())
59    }
60
61    fn to_cpp(&self) -> String {
62        let mut ops = self.operands.iter().peekable();
63        let mut res = String::new();
64        // If expression starts with -1, transform it into -
65        if let Some(first_op) = ops.peek()
66            && first_op.is_neg_one()
67        {
68            ops.next();
69            res += "-";
70        }
71        ops.map(|op| match op.known_expr() {
72            KnownExpr::Add(_) if self.operands.len() > 1 => format!("({})", op.to_cpp()),
73            KnownExpr::Pow(pow) if self.operands.len() > 1 => format!("({})", pow.to_cpp()),
74            KnownExpr::Rational(r) if self.operands.len() > 1 => format!("({})", r.to_cpp()),
75            _ => op.to_cpp(),
76        })
77        .enumerate()
78        .for_each(|(i, op)| {
79            if i > 0 {
80                res += " * ";
81            }
82            res += &op
83        });
84        res
85    }
86
87    fn expand(&self) -> Box<dyn Expr> {
88        // 2 * (x + y) * (z + g) = 2xz + 2xg + 2yz + 2yg
89        // 2
90        // (x + y)
91
92        // xy -> xy
93
94        // (x + y) * z -> xz + yz
95
96        let mut res: Vec<Box<dyn Expr>> = Vec::with_capacity(self.operands.len());
97        res.push(Integer::new_box(1));
98
99        for op in &self.operands {
100            let op = op.expand();
101
102            match KnownExpr::from_expr_box(&op) {
103                KnownExpr::Add(Add { operands }) => {
104                    res = res
105                        .iter()
106                        .flat_map(|x| {
107                            operands
108                                .iter()
109                                .flat_map(|expr| match KnownExpr::from_expr_box(expr) {
110                                    KnownExpr::Add(Add { operands }) => operands.clone(),
111                                    _ => vec![expr.clone_box()],
112                                })
113                                .map(move |addendum| x * &addendum)
114                        })
115                        .collect();
116                }
117                _ => {
118                    for new_op in &mut res {
119                        *new_op *= &op;
120                    }
121                }
122            }
123        }
124
125        if res.len() == 1 {
126            res[0].clone_box()
127        } else {
128            Box::new(Add { operands: res })
129        }
130    }
131}
132
133// trait ArgIterOps {
134//     fn map_exprs(&self) ->
135// }
136//
137// impl
138
139impl Mul {
140    pub fn new_box(operands: Vec<&Box<dyn Expr>>) -> Box<dyn Expr> {
141        Box::new(Mul {
142            operands: operands.iter().copied().cloned().collect(),
143        })
144    }
145
146    pub fn new<'a, Ops: IntoIterator<Item = &'a dyn Expr>>(operands: Ops) -> Self {
147        Mul {
148            operands: operands.into_iter().map(|e| e.clone_box()).collect(),
149        }
150    }
151
152    pub fn new_move(operands: Vec<Box<dyn Expr>>) -> Self {
153        Mul { operands }
154    }
155}
156
157impl<E: Expr> std::ops::Mul<&E> for Box<dyn Expr> {
158    type Output = Box<dyn Expr>;
159
160    fn mul(self, rhs: &E) -> Self::Output {
161        &*self * rhs.get_ref()
162    }
163}
164
165impl std::ops::Mul for &Box<dyn Expr> {
166    type Output = Box<dyn Expr>;
167
168    fn mul(self, rhs: &Box<dyn Expr>) -> Self::Output {
169        &**self * &**rhs
170    }
171}
172
173impl std::ops::Mul for Box<dyn Expr> {
174    type Output = Box<dyn Expr>;
175
176    fn mul(self, rhs: Box<dyn Expr>) -> Self::Output {
177        &*self * &*rhs
178    }
179}
180
181impl std::ops::Mul<&dyn Expr> for Box<dyn Expr> {
182    type Output = Box<dyn Expr>;
183
184    fn mul(self, rhs: &dyn Expr) -> Self::Output {
185        &*self * rhs
186    }
187}
188
189impl std::ops::Mul<&Box<dyn Expr>> for Box<dyn Expr> {
190    type Output = Box<dyn Expr>;
191
192    fn mul(self, rhs: &Box<dyn Expr>) -> Self::Output {
193        &*self * &**rhs
194    }
195}
196
197impl std::ops::Mul<&dyn Expr> for &Box<dyn Expr> {
198    type Output = Box<dyn Expr>;
199
200    fn mul(self, rhs: &dyn Expr) -> Self::Output {
201        &**self * rhs
202    }
203}
204
205impl std::ops::Mul<isize> for Box<dyn Expr> {
206    type Output = Box<dyn Expr>;
207
208    fn mul(self, rhs: isize) -> Self::Output {
209        Integer::new_box(rhs) * &*self
210    }
211}
212
213impl std::ops::Add for Mul {
214    type Output = Add;
215
216    fn add(self, rhs: Self) -> Self::Output {
217        Add::new([&self as &dyn Expr, &rhs as &dyn Expr])
218    }
219}
220
221impl std::ops::MulAssign<&dyn Expr> for Box<dyn Expr> {
222    fn mul_assign(&mut self, rhs: &dyn Expr) {
223        *self = &**self * rhs;
224    }
225}
226
227impl std::ops::MulAssign<&Box<dyn Expr>> for Box<dyn Expr> {
228    fn mul_assign(&mut self, rhs: &Box<dyn Expr>) {
229        *self *= &**rhs;
230    }
231}
232
233impl std::ops::MulAssign for Box<dyn Expr> {
234    fn mul_assign(&mut self, rhs: Box<dyn Expr>) {
235        *self *= &*rhs;
236    }
237}
238
239impl std::ops::Mul for &dyn Expr {
240    type Output = Box<dyn Expr>;
241
242    fn mul(self, rhs: Self) -> Self::Output {
243        if self.is_zero() || rhs.is_zero() {
244            return Integer::new_box(0);
245        }
246        if self.is_one() {
247            return rhs.clone_box();
248        }
249        if rhs.is_one() {
250            return self.clone_box();
251        }
252
253        match (self.known_expr(), rhs.known_expr()) {
254            (KnownExpr::Rational(a), KnownExpr::Rational(b)) => return Box::new(*a * *b),
255            (KnownExpr::Integer(a), KnownExpr::Integer(b)) => {
256                return Integer::new_box(a.value * b.value);
257            }
258            (KnownExpr::Integer(a), KnownExpr::Rational(b)) => return Box::new(*b * a),
259            (KnownExpr::Rational(a), KnownExpr::Integer(b)) => return Box::new(*a * b),
260            (KnownExpr::Pow(a), KnownExpr::Pow(b))
261                if a.base().is_number()
262                    && b.base().is_number()
263                    && b.exponent().is_number()
264                    && a.exponent() == b.exponent() =>
265            {
266                return (a.base() * b.base()).pow(&a.exponent().clone_box());
267            }
268            _ => (),
269        }
270
271        let (coeff_a, lhs) = self.get_coeff();
272        let (coeff_b, rhs) = rhs.get_coeff();
273
274        let coeff = (coeff_a) * coeff_b;
275        let mut new_operands: Vec<&Box<dyn Expr>> = Vec::new();
276
277        match (
278            KnownExpr::from_expr_box(&lhs),
279            KnownExpr::from_expr_box(&rhs),
280        ) {
281            (KnownExpr::Mul(Mul { operands: a }), KnownExpr::Mul(Mul { operands: b })) => {
282                a.iter()
283                    .chain(b.iter())
284                    .for_each(|op| new_operands.push(&*op));
285            }
286            (_, KnownExpr::Mul(Mul { operands })) => {
287                if !lhs.is_one() {
288                    new_operands.push(&lhs);
289                }
290                operands.iter().for_each(|op| new_operands.push(&*op));
291            }
292            (KnownExpr::Mul(Mul { operands }), _) => {
293                operands.iter().for_each(|op| new_operands.push(&*op));
294                if !rhs.is_one() {
295                    new_operands.push(&rhs);
296                }
297            }
298
299            _ => {
300                if !lhs.is_one() {
301                    new_operands.push(&lhs);
302                }
303                if !rhs.is_one() {
304                    new_operands.push(&rhs);
305                }
306            }
307        }
308        let coeff = coeff.simplify();
309        if !coeff.is_one() {
310            new_operands.insert(0, &coeff);
311        }
312
313        let mut operands_exponents: IndexMap<Box<dyn Expr>, Box<dyn Expr>> = IndexMap::new();
314
315        for op in new_operands
316            .iter()
317            // Split up factors fo multiplication and powers
318            .flat_map(|op| match op.known_expr() {
319                KnownExpr::Mul(Mul { operands }) => operands.clone(),
320                KnownExpr::Pow(Pow { base, exponent })
321                    if matches!(base.known_expr(), KnownExpr::Mul(Mul { .. })) =>
322                {
323                    let mul = base.as_mul().unwrap();
324                    mul.operands.iter().map(|op| op.pow(exponent)).collect()
325                }
326                _ => vec![op.clone_box()],
327            })
328        {
329            let (expr, exponent) = op.get_exponent();
330            let entry = operands_exponents
331                .entry(expr)
332                .or_insert(Integer::zero_box());
333            *entry += exponent;
334        }
335        let mut new_operands = Vec::with_capacity(operands_exponents.len());
336
337        for (expr, exponent) in operands_exponents {
338            if exponent.is_zero() {
339                continue;
340            }
341
342            if exponent.is_one() {
343                new_operands.push(expr);
344            } else {
345                new_operands.push(Box::new(Pow {
346                    base: expr,
347                    exponent,
348                }));
349            }
350        }
351
352        if new_operands.len() == 0 {
353            return Integer::one_box();
354        }
355
356        if new_operands.len() == 1 {
357            return new_operands[0].clone_box();
358        }
359
360        Box::new(Mul {
361            operands: new_operands,
362        })
363    }
364}
365
366impl std::ops::Mul<Box<dyn Expr>> for &dyn Expr {
367    type Output = Box<dyn Expr>;
368
369    fn mul(self, rhs: Box<dyn Expr>) -> Self::Output {
370        self * &*rhs
371    }
372}
373
374impl std::ops::Div for &dyn Expr {
375    type Output = Box<dyn Expr>;
376
377    fn div(self, rhs: Self) -> Self::Output {
378        self * rhs.ipow(-1)
379    }
380}
381impl std::ops::Div<&dyn Expr> for Box<dyn Expr> {
382    type Output = Box<dyn Expr>;
383
384    fn div(self, rhs: &dyn Expr) -> Self::Output {
385        &*self / rhs
386    }
387}
388
389impl<E: Expr> std::ops::Div<&E> for Box<dyn Expr> {
390    type Output = Box<dyn Expr>;
391
392    fn div(self, rhs: &E) -> Self::Output {
393        &*self / rhs.get_ref()
394    }
395}
396
397impl std::ops::Div for Box<dyn Expr> {
398    type Output = Box<dyn Expr>;
399
400    fn div(self, rhs: Box<dyn Expr>) -> Self::Output {
401        &*self / &*rhs
402    }
403}
404
405impl std::ops::Div<&Box<dyn Expr>> for Box<dyn Expr> {
406    type Output = Box<dyn Expr>;
407
408    fn div(self, rhs: &Box<dyn Expr>) -> Self::Output {
409        &*self / &**rhs
410    }
411}
412
413impl std::ops::DivAssign<&dyn Expr> for Box<dyn Expr> {
414    fn div_assign(&mut self, rhs: &dyn Expr) {
415        *self = &**self / rhs
416    }
417}
418
#[cfg(test)]
mod tests {
    use crate::{symbol, symbols};

    use super::*;

    /// Multiplying plain symbols yields a single flat `Mul`, in operand order.
    #[test]
    fn test_srepr() {
        let a = Symbol::new_box("a");
        let b = Symbol::new_box("b");
        let c = Symbol::new_box("c");
        let d = Symbol::new_box("d");

        let product = a * b * c * d;

        assert_eq!(
            product.srepr(),
            "Mul(Symbol(a), Symbol(b), Symbol(c), Symbol(d))"
        );
    }

    /// A negated power times symbols keeps `-1` as the leading factor.
    #[test]
    fn test_srepr_advanced() {
        let c = Symbol::new_box("c");
        let u = Symbol::new_box("u");
        let laplacian = Symbol::new_box("laplacian");

        let product = -c.ipow(2) * laplacian * u;

        assert_eq!(
            product.srepr(),
            "Mul(Integer(-1), Pow(Symbol(c), Integer(2)), Symbol(Δ), Symbol(u))"
        );
    }

    /// A derivative minus a product is represented as an `Add` with a
    /// `-1`-led `Mul` term.
    #[test]
    fn test_srepr_difficult() {
        let c = &Symbol::new_box("c");
        let u = &Symbol::new_box("u");
        let t = &Symbol::new_box("t");
        let laplacian = &Symbol::new_box("laplacian");

        let wave = &(Diff::new(u, vec![t, t]) - c.ipow(2) * laplacian * u);

        assert_eq!(
            wave.srepr(),
            "Add(Diff(Symbol(u), ((Symbol(t), 2))), Mul(Integer(-1), Pow(Symbol(c), Integer(2)), Symbol(Δ), Symbol(u)))"
        );
    }

    /// Division is represented as multiplication by a `-1` power.
    #[test]
    fn test_div() {
        let a = Symbol::new_box("a");
        let b = Symbol::new_box("b");
        let c = Symbol::new_box("c");

        let quotient = (a - b) / c;

        assert_eq!(
            quotient.srepr(),
            "Mul(Add(Symbol(a), Mul(Integer(-1), Symbol(b))), Pow(Symbol(c), Integer(-1)))"
        );
    }

    /// A common factor cancels: (a*b) / (a*c) == b / c.
    #[test]
    fn test_div_of_product_simplifies() {
        let [a, b, c] = symbols!("a", "b", "c");

        let lhs = a * b / (a * c);
        let rhs = b / c;

        assert_eq!(&lhs, &rhs);
    }

    #[test]
    #[ignore]
    fn test_simplify_frac_mul() {
        let product = Mul::new_move(vec![Rational::new_box(1, 2), Rational::new_box(1, 2)]);

        // NOTE(review): the expected value is a placeholder; the product is
        // 1/4, so this should compare against the srepr of that rational.
        assert_eq!(product.simplify().srepr(), "")
    }

    /// (a - 1) * a evaluated at a = 1/2 expands and simplifies to -1/4.
    #[test]
    fn test_weird_issue() {
        let a = symbol!("a");
        let product = (a - Integer::new(1).get_ref()) * a;
        let substituted = product.subs(&[[a.clone_box(), Rational::new_box(1, 2)]]);

        assert_eq!(
            &substituted.expand().simplify(),
            &Rational::new_box(-1, 4)
        )
    }
}