1use super::*;
2use indexmap::IndexMap;
3
/// A product of expressions.
///
/// Building this struct directly (struct literal, `Mul::new`, `Mul::new_move`,
/// `Mul::new_box`) stores the factors exactly as given; the `*` operator on
/// `&dyn Expr` is what produces flattened, coefficient-first products.
#[derive(Clone)]
pub struct Mul {
    /// The factors of the product, in order.
    pub operands: Vec<Box<dyn Expr>>,
}
8
9impl Expr for Mul {
10 fn known_expr(&self) -> KnownExpr {
11 KnownExpr::Mul(self)
12 }
13 fn get_ref<'a>(&'a self) -> &'a dyn Expr {
14 self as &dyn Expr
15 }
16
17 fn as_mul(&self) -> Option<&Mul> {
18 Some(self)
19 }
20 fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
21 self.operands.iter().for_each(|e| f(&**e));
22 }
23
24 fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
25 let args: Vec<Box<dyn Expr>> = args.iter().cloned().collect();
26 Box::new(Mul { operands: args })
27 }
28
29 fn clone_box(&self) -> Box<dyn Expr> {
30 Box::new(self.clone())
31 }
32
33 fn str(&self) -> String {
34 let pieces: Vec<_> = self
35 .operands
36 .iter()
37 .enumerate()
38 .map(|(i, op)| match KnownExpr::from_expr_box(op) {
39 KnownExpr::Integer(Integer { value: -1 }) if i == 0 => "-".to_string(),
40 KnownExpr::Add(_) if self.operands.len() > 1 => format!("({})", op.str()),
41 KnownExpr::Pow(pow) if self.operands.len() > 1 => format!("({})", pow.str()),
42 KnownExpr::Rational(r) if self.operands.len() > 1 => format!("({})", r.str()),
43 KnownExpr::Symbol(Symbol { name })
44 if i < self.operands.len() - 1 && name.len() > 1 =>
45 {
46 format!("{name}.")
47 }
48 KnownExpr::Integer(Integer { value }) => value.to_string(),
49 KnownExpr::Symbol(Symbol { name }) => name.to_string(),
50 _ if self.operands.len() > 1 => format!("({})", op.str()),
51 _ => op.str(),
52 })
53 .collect();
54 format!("{}", pieces.join(""))
55 }
56
57 fn is_number(&self) -> bool {
58 self.operands.iter().all(|op| op.is_number())
59 }
60
61 fn to_cpp(&self) -> String {
62 let mut ops = self.operands.iter().peekable();
63 let mut res = String::new();
64 if let Some(first_op) = ops.peek()
66 && first_op.is_neg_one()
67 {
68 ops.next();
69 res += "-";
70 }
71 ops.map(|op| match op.known_expr() {
72 KnownExpr::Add(_) if self.operands.len() > 1 => format!("({})", op.to_cpp()),
73 KnownExpr::Pow(pow) if self.operands.len() > 1 => format!("({})", pow.to_cpp()),
74 KnownExpr::Rational(r) if self.operands.len() > 1 => format!("({})", r.to_cpp()),
75 _ => op.to_cpp(),
76 })
77 .enumerate()
78 .for_each(|(i, op)| {
79 if i > 0 {
80 res += " * ";
81 }
82 res += &op
83 });
84 res
85 }
86
87 fn expand(&self) -> Box<dyn Expr> {
88 let mut res: Vec<Box<dyn Expr>> = Vec::with_capacity(self.operands.len());
97 res.push(Integer::new_box(1));
98
99 for op in &self.operands {
100 let op = op.expand();
101
102 match KnownExpr::from_expr_box(&op) {
103 KnownExpr::Add(Add { operands }) => {
104 res = res
105 .iter()
106 .flat_map(|x| {
107 operands
108 .iter()
109 .flat_map(|expr| match KnownExpr::from_expr_box(expr) {
110 KnownExpr::Add(Add { operands }) => operands.clone(),
111 _ => vec![expr.clone_box()],
112 })
113 .map(move |addendum| x * &addendum)
114 })
115 .collect();
116 }
117 _ => {
118 for new_op in &mut res {
119 *new_op *= &op;
120 }
121 }
122 }
123 }
124
125 if res.len() == 1 {
126 res[0].clone_box()
127 } else {
128 Box::new(Add { operands: res })
129 }
130 }
131}
132
133impl Mul {
140 pub fn new_box(operands: Vec<&Box<dyn Expr>>) -> Box<dyn Expr> {
141 Box::new(Mul {
142 operands: operands.iter().copied().cloned().collect(),
143 })
144 }
145
146 pub fn new<'a, Ops: IntoIterator<Item = &'a dyn Expr>>(operands: Ops) -> Self {
147 Mul {
148 operands: operands.into_iter().map(|e| e.clone_box()).collect(),
149 }
150 }
151
152 pub fn new_move(operands: Vec<Box<dyn Expr>>) -> Self {
153 Mul { operands }
154 }
155}
156
157impl<E: Expr> std::ops::Mul<&E> for Box<dyn Expr> {
158 type Output = Box<dyn Expr>;
159
160 fn mul(self, rhs: &E) -> Self::Output {
161 &*self * rhs.get_ref()
162 }
163}
164
165impl std::ops::Mul for &Box<dyn Expr> {
166 type Output = Box<dyn Expr>;
167
168 fn mul(self, rhs: &Box<dyn Expr>) -> Self::Output {
169 &**self * &**rhs
170 }
171}
172
173impl std::ops::Mul for Box<dyn Expr> {
174 type Output = Box<dyn Expr>;
175
176 fn mul(self, rhs: Box<dyn Expr>) -> Self::Output {
177 &*self * &*rhs
178 }
179}
180
181impl std::ops::Mul<&dyn Expr> for Box<dyn Expr> {
182 type Output = Box<dyn Expr>;
183
184 fn mul(self, rhs: &dyn Expr) -> Self::Output {
185 &*self * rhs
186 }
187}
188
189impl std::ops::Mul<&Box<dyn Expr>> for Box<dyn Expr> {
190 type Output = Box<dyn Expr>;
191
192 fn mul(self, rhs: &Box<dyn Expr>) -> Self::Output {
193 &*self * &**rhs
194 }
195}
196
197impl std::ops::Mul<&dyn Expr> for &Box<dyn Expr> {
198 type Output = Box<dyn Expr>;
199
200 fn mul(self, rhs: &dyn Expr) -> Self::Output {
201 &**self * rhs
202 }
203}
204
205impl std::ops::Mul<isize> for Box<dyn Expr> {
206 type Output = Box<dyn Expr>;
207
208 fn mul(self, rhs: isize) -> Self::Output {
209 Integer::new_box(rhs) * &*self
210 }
211}
212
213impl std::ops::Add for Mul {
214 type Output = Add;
215
216 fn add(self, rhs: Self) -> Self::Output {
217 Add::new([&self as &dyn Expr, &rhs as &dyn Expr])
218 }
219}
220
221impl std::ops::MulAssign<&dyn Expr> for Box<dyn Expr> {
222 fn mul_assign(&mut self, rhs: &dyn Expr) {
223 *self = &**self * rhs;
224 }
225}
226
227impl std::ops::MulAssign<&Box<dyn Expr>> for Box<dyn Expr> {
228 fn mul_assign(&mut self, rhs: &Box<dyn Expr>) {
229 *self *= &**rhs;
230 }
231}
232
233impl std::ops::MulAssign for Box<dyn Expr> {
234 fn mul_assign(&mut self, rhs: Box<dyn Expr>) {
235 *self *= &*rhs;
236 }
237}
238
239impl std::ops::Mul for &dyn Expr {
240 type Output = Box<dyn Expr>;
241
242 fn mul(self, rhs: Self) -> Self::Output {
243 if self.is_zero() || rhs.is_zero() {
244 return Integer::new_box(0);
245 }
246 if self.is_one() {
247 return rhs.clone_box();
248 }
249 if rhs.is_one() {
250 return self.clone_box();
251 }
252
253 match (self.known_expr(), rhs.known_expr()) {
254 (KnownExpr::Rational(a), KnownExpr::Rational(b)) => return Box::new(*a * *b),
255 (KnownExpr::Integer(a), KnownExpr::Integer(b)) => {
256 return Integer::new_box(a.value * b.value);
257 }
258 (KnownExpr::Integer(a), KnownExpr::Rational(b)) => return Box::new(*b * a),
259 (KnownExpr::Rational(a), KnownExpr::Integer(b)) => return Box::new(*a * b),
260 (KnownExpr::Pow(a), KnownExpr::Pow(b))
261 if a.base().is_number()
262 && b.base().is_number()
263 && b.exponent().is_number()
264 && a.exponent() == b.exponent() =>
265 {
266 return (a.base() * b.base()).pow(&a.exponent().clone_box());
267 }
268 _ => (),
269 }
270
271 let (coeff_a, lhs) = self.get_coeff();
272 let (coeff_b, rhs) = rhs.get_coeff();
273
274 let coeff = (coeff_a) * coeff_b;
275 let mut new_operands: Vec<&Box<dyn Expr>> = Vec::new();
276
277 match (
278 KnownExpr::from_expr_box(&lhs),
279 KnownExpr::from_expr_box(&rhs),
280 ) {
281 (KnownExpr::Mul(Mul { operands: a }), KnownExpr::Mul(Mul { operands: b })) => {
282 a.iter()
283 .chain(b.iter())
284 .for_each(|op| new_operands.push(&*op));
285 }
286 (_, KnownExpr::Mul(Mul { operands })) => {
287 if !lhs.is_one() {
288 new_operands.push(&lhs);
289 }
290 operands.iter().for_each(|op| new_operands.push(&*op));
291 }
292 (KnownExpr::Mul(Mul { operands }), _) => {
293 operands.iter().for_each(|op| new_operands.push(&*op));
294 if !rhs.is_one() {
295 new_operands.push(&rhs);
296 }
297 }
298
299 _ => {
300 if !lhs.is_one() {
301 new_operands.push(&lhs);
302 }
303 if !rhs.is_one() {
304 new_operands.push(&rhs);
305 }
306 }
307 }
308 let coeff = coeff.simplify();
309 if !coeff.is_one() {
310 new_operands.insert(0, &coeff);
311 }
312
313 let mut operands_exponents: IndexMap<Box<dyn Expr>, Box<dyn Expr>> = IndexMap::new();
314
315 for op in new_operands
316 .iter()
317 .flat_map(|op| match op.known_expr() {
319 KnownExpr::Mul(Mul { operands }) => operands.clone(),
320 KnownExpr::Pow(Pow { base, exponent })
321 if matches!(base.known_expr(), KnownExpr::Mul(Mul { .. })) =>
322 {
323 let mul = base.as_mul().unwrap();
324 mul.operands.iter().map(|op| op.pow(exponent)).collect()
325 }
326 _ => vec![op.clone_box()],
327 })
328 {
329 let (expr, exponent) = op.get_exponent();
330 let entry = operands_exponents
331 .entry(expr)
332 .or_insert(Integer::zero_box());
333 *entry += exponent;
334 }
335 let mut new_operands = Vec::with_capacity(operands_exponents.len());
336
337 for (expr, exponent) in operands_exponents {
338 if exponent.is_zero() {
339 continue;
340 }
341
342 if exponent.is_one() {
343 new_operands.push(expr);
344 } else {
345 new_operands.push(Box::new(Pow {
346 base: expr,
347 exponent,
348 }));
349 }
350 }
351
352 if new_operands.len() == 0 {
353 return Integer::one_box();
354 }
355
356 if new_operands.len() == 1 {
357 return new_operands[0].clone_box();
358 }
359
360 Box::new(Mul {
361 operands: new_operands,
362 })
363 }
364}
365
366impl std::ops::Mul<Box<dyn Expr>> for &dyn Expr {
367 type Output = Box<dyn Expr>;
368
369 fn mul(self, rhs: Box<dyn Expr>) -> Self::Output {
370 self * &*rhs
371 }
372}
373
374impl std::ops::Div for &dyn Expr {
375 type Output = Box<dyn Expr>;
376
377 fn div(self, rhs: Self) -> Self::Output {
378 self * rhs.ipow(-1)
379 }
380}
381impl std::ops::Div<&dyn Expr> for Box<dyn Expr> {
382 type Output = Box<dyn Expr>;
383
384 fn div(self, rhs: &dyn Expr) -> Self::Output {
385 &*self / rhs
386 }
387}
388
389impl<E: Expr> std::ops::Div<&E> for Box<dyn Expr> {
390 type Output = Box<dyn Expr>;
391
392 fn div(self, rhs: &E) -> Self::Output {
393 &*self / rhs.get_ref()
394 }
395}
396
397impl std::ops::Div for Box<dyn Expr> {
398 type Output = Box<dyn Expr>;
399
400 fn div(self, rhs: Box<dyn Expr>) -> Self::Output {
401 &*self / &*rhs
402 }
403}
404
405impl std::ops::Div<&Box<dyn Expr>> for Box<dyn Expr> {
406 type Output = Box<dyn Expr>;
407
408 fn div(self, rhs: &Box<dyn Expr>) -> Self::Output {
409 &*self / &**rhs
410 }
411}
412
413impl std::ops::DivAssign<&dyn Expr> for Box<dyn Expr> {
414 fn div_assign(&mut self, rhs: &dyn Expr) {
415 *self = &**self / rhs
416 }
417}
418
#[cfg(test)]
mod tests {
    use crate::{symbol, symbols};

    use super::*;

    /// A chain of symbol products becomes a single flat `Mul`.
    #[test]
    fn test_srepr() {
        let a = Symbol::new_box("a");
        let b = Symbol::new_box("b");
        let c = Symbol::new_box("c");
        let d = Symbol::new_box("d");

        let expr = a * b * c * d;
        let expected = "Mul(Symbol(a), Symbol(b), Symbol(c), Symbol(d))";

        assert_eq!(expr.srepr(), expected);
    }

    /// Negation shows up as a leading `Integer(-1)` factor.
    // NOTE(review): the symbol is created as "laplacian" but is expected to
    // print as `Δ`; presumably `Symbol` maps that name — confirm.
    #[test]
    fn test_srepr_advanced() {
        let c = Symbol::new_box("c");
        let u = Symbol::new_box("u");
        let laplacian = Symbol::new_box("laplacian");

        let expr = -c.ipow(2) * laplacian * u;
        let expected = "Mul(Integer(-1), Pow(Symbol(c), Integer(2)), Symbol(Δ), Symbol(u))";

        assert_eq!(expr.srepr(), expected);
    }

    /// Subtracting a product yields an `Add` whose second term carries `Integer(-1)`.
    #[test]
    fn test_srepr_difficult() {
        let c = &Symbol::new_box("c");
        let u = &Symbol::new_box("u");
        let t = &Symbol::new_box("t");
        let laplacian = &Symbol::new_box("laplacian");

        let expr = &(Diff::new(u, vec![t, t]) - c.ipow(2) * laplacian * u);
        let expected = "Add(Diff(Symbol(u), ((Symbol(t), 2))), Mul(Integer(-1), Pow(Symbol(c), Integer(2)), Symbol(Δ), Symbol(u)))";

        assert_eq!(expr.srepr(), expected);
    }

    /// Division is represented as multiplication by `Pow(_, Integer(-1))`.
    #[test]
    fn test_div() {
        let a = Symbol::new_box("a");
        let b = Symbol::new_box("b");
        let c = Symbol::new_box("c");
        let expr = (a - b) / c;
        assert_eq!(
            expr.srepr(),
            "Mul(Add(Symbol(a), Mul(Integer(-1), Symbol(b))), Pow(Symbol(c), Integer(-1)))"
        );
    }

    /// A common factor cancels: (a*b)/(a*c) == b/c.
    #[test]
    fn test_div_of_product_simplifies() {
        let [a, b, c] = symbols!("a", "b", "c");

        assert_eq!(&(a * b / (a * c)), &(b / c));
    }

    /// Ignored placeholder: the expected string is empty, so it would fail if
    /// run; the intended simplification of 1/2 * 1/2 is not pinned yet.
    #[test]
    #[ignore]
    fn test_simplify_frac_mul() {
        let expr = Mul::new_move(vec![Rational::new_box(1, 2), Rational::new_box(1, 2)]);

        assert_eq!(expr.simplify().srepr(), "")
    }

    /// (a - 1) * a evaluated at a = 1/2 must expand and simplify to -1/4.
    #[test]
    fn test_weird_issue() {
        let a = symbol!("a");
        let expr = (a - Integer::new(1).get_ref()) * a;
        let expr = expr.subs(&[[a.clone_box(), Rational::new_box(1, 2)]]);

        assert_eq!(&expr.expand().simplify(), &Rational::new_box(-1, 4))
    }
}