1use std::collections::HashMap;
2
3use crate::symbol;
4use indexmap::IndexMap;
5use itertools::Itertools;
6
7use super::*;
8
9#[derive(Clone)]
10pub struct Add {
11 pub operands: Vec<Box<dyn Expr>>,
12}
13
14impl Add {
15 pub fn new_box(operands: Vec<&Box<dyn Expr>>) -> Box<dyn Expr> {
16 Box::new(Add {
17 operands: operands.iter().copied().cloned().collect(),
18 })
19 }
20
21 pub fn new_box_v2(operands: Vec<Box<dyn Expr>>) -> Box<dyn Expr> {
22 if operands.len() == 0 {
23 Integer::new_box(0)
24 } else if operands.len() == 1 {
25 operands[0].clone_box()
26 } else {
27 Box::new(Add { operands })
28 }
29 }
30
31 pub fn new<'a, Ops: IntoIterator<Item = &'a dyn Expr>>(operands: Ops) -> Self {
32 Add {
33 operands: operands.into_iter().map(|e| e.clone_box()).collect(),
34 }
35 }
36
37 pub fn new_v2(ops: Vec<Box<dyn Expr>>) -> Self {
38 Add { operands: ops }
39 }
40
41 pub fn term_coeffs(&self) -> IndexMap<Box<dyn Expr>, Rational> {
42 let mut term_coeffs: IndexMap<Box<dyn Expr>, Rational> = IndexMap::new();
43
44 for op in self.operands.iter() {
45 let (coeff, expr) = op.get_coeff();
47
48 let entry = term_coeffs.entry(expr).or_insert(Rational::zero());
49 *entry += coeff;
50 }
51 term_coeffs
52 .into_iter()
53 .filter(|(_, v)| !v.is_zero())
54 .collect()
55 }
56}
57
58impl From<Vec<&Box<dyn Expr>>> for Add {
59 fn from(value: Vec<&Box<dyn Expr>>) -> Self {
60 Add::new(value.iter().map(|x| &***x).collect::<Vec<_>>())
61 }
62}
63
64impl Expr for Add {
65 fn known_expr(&self) -> KnownExpr {
66 KnownExpr::Add(self)
67 }
68 fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
69 self.operands.iter().for_each(|e| f(&**e));
70 }
71
72 fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
73 let args: Vec<Box<dyn Expr>> = args.iter().cloned().collect();
74 Box::new(Add { operands: args })
75 }
76
77 fn clone_box(&self) -> Box<dyn Expr> {
78 Box::new(self.clone())
79 }
80
81 fn str(&self) -> String {
82 let pieces: Vec<_> = self
83 .operands
84 .iter()
85 .enumerate()
86 .map(|(i, op)| match KnownExpr::from_expr_box(op) {
87 KnownExpr::Mul(Mul { operands }) if operands.len() > 0 && i > 0 => {
88 match KnownExpr::from_expr_box(&operands[0]) {
89 KnownExpr::Integer(Integer { value: -1 }) => {
90 let mul = op.str();
91 format!(" - {}", mul[1..].to_string())
92 }
93 _ => format!(" + {}", op.str()),
100 }
101 }
102
103 KnownExpr::Integer(integer) if integer.value < 0 => {
104 format!(" - {}", op.str()[1..].to_string())
105 }
106
107 _ if i > 0 => format!(" + {}", op.str()),
108 _ => op.str(),
109 })
110 .collect();
111 format!("{}", pieces.join(""))
112 }
113
114 fn to_cpp(&self) -> String {
116 let pieces: Vec<_> = self
117 .operands
118 .iter()
119 .enumerate()
120 .map(|(i, op)| match KnownExpr::from_expr_box(op) {
121 KnownExpr::Mul(Mul { operands }) if operands.len() > 0 && i > 0 => {
122 match KnownExpr::from_expr_box(&operands[0]) {
123 KnownExpr::Integer(Integer { value: -1 }) => {
124 let mul = op.to_cpp();
125 format!(" - {}", mul[1..].to_string())
126 }
127 _ => format!(" + {}", op.to_cpp()),
128 }
129 }
130 _ if i > 0 => format!(" + {}", op.to_cpp()),
131 _ => op.to_cpp(),
132 })
133 .collect();
134 format!("{}", pieces.join(""))
135 }
136
137 fn simplify(&self) -> Box<dyn Expr> {
138 if self.operands.len() == 2 {
139 match (self.operands[0].known_expr(), self.operands[1].known_expr()) {
140 (KnownExpr::Integer(a), KnownExpr::Integer(b)) => {
141 return Integer::new_box(a.value + b.value);
142 }
143 (KnownExpr::Rational(r1), KnownExpr::Rational(r2)) => return (r1 + r2).simplify(),
144 (KnownExpr::Integer(a), KnownExpr::Rational(r2)) => return a + r2,
145 (KnownExpr::Rational(r1), KnownExpr::Integer(b)) => return r1 + b,
146 _ => (),
147 }
148 }
149 return self.from_args(
150 self.args()
151 .iter()
152 .map(|a| a.map_expr(&|e| e.simplify()))
153 .collect(),
154 );
155 }
160
161 fn simplify_with_dimension(&self, dim: usize) -> Box<dyn Expr> {
162 let expr = self;
163
164 match expr.known_expr() {
165 KnownExpr::Add(Add { operands }) => {
166 let operands = operands
167 .iter()
168 .map(|op| op.simplify_with_dimension(dim))
169 .collect_vec();
170 let mut snd_ord_spatial_derivatives: HashMap<&Box<dyn Expr>, Vec<usize>> =
173 HashMap::new();
174 let mut res_ops: Vec<Box<dyn Expr>> = Vec::with_capacity(operands.len());
175
176 for op in &operands {
177 match op.known_expr() {
178 KnownExpr::Diff(Diff { f, vars }) => {
179 let entry = snd_ord_spatial_derivatives.entry(f).or_insert(vec![0; 3]);
180
181 if vars.len() == 1 {
182 let (var, order) = vars.iter().next().unwrap();
183 if *order != 2 {
184 res_ops.push(op.clone());
185 continue;
186 }
187 match var.name.as_str() {
188 "x" => entry[0] += 1,
189 "y" => entry[1] += 1,
190 "z" => entry[2] += 1,
191 _ => res_ops.push(op.clone()),
192 }
193 } else {
194 res_ops.push(op.clone());
195 }
196 }
197
198 _ => res_ops.push(op.clone()),
199 }
200 }
201
202 let laplacian = symbol!("laplacian");
203
204 for (f, mut counts) in snd_ord_spatial_derivatives {
205 let min = *counts[0..dim].iter().min().unwrap();
206
207 if min == 1 {
208 res_ops.push(laplacian * f.get_ref());
209 } else if min >= 1 {
210 res_ops.push(laplacian * min * f.get_ref());
211 }
212 for k in 0..dim {
213 counts[k] -= min;
214 }
215 for k in 0..dim {
216 let count = counts[k];
217 if count > 0 {
218 todo!()
219 }
220 }
221 }
222
223 Add::new_box_v2(res_ops).simplify()
224 }
225 _ => expr.simplify_with_dimension(dim),
226 }
227 }
228
229 fn expand(&self) -> Box<dyn Expr> {
230 let operands: Vec<Box<dyn Expr>> = self
231 .operands
232 .iter()
233 .flat_map(|op| {
234 let op = op.expand();
235
236 match KnownExpr::from_expr_box(&op) {
237 KnownExpr::Add(Add { operands }) => operands.clone(),
238 _ => vec![op.clone_box()],
239 }
240 })
241 .collect();
242
243 if operands.len() == 0 {
244 Integer::new_box(0)
245 } else if operands.len() == 1 {
246 operands[0].clone()
247 } else {
248 Box::new(Add { operands })
249 }
250 }
251
252 fn get_ref<'a>(&'a self) -> &'a dyn Expr {
253 self as &dyn Expr
254 }
255
256 fn terms<'a>(&'a self) -> Box<dyn Iterator<Item = &'a dyn Expr> + 'a> {
257 Box::new(self.operands.iter().map(|o| &**o))
258 }
259}
260
261impl std::fmt::Debug for Add {
262 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263 write!(f, "{}", self.srepr())
264 }
265}
266
267impl std::ops::Add for &dyn Expr {
268 type Output = Box<dyn Expr>;
269
270 fn add(self, rhs: Self) -> Self::Output {
271 if self.is_zero() {
272 return rhs.clone_box();
273 }
274 if rhs.is_zero() {
275 return self.clone_box();
276 }
277 if self == &*-rhs {
278 return Integer::new_box(0);
279 }
280
281 let mut term_coeffs: IndexMap<Box<dyn Expr>, Rational> = IndexMap::new();
282
283 match (KnownExpr::from_expr(self), KnownExpr::from_expr(rhs)) {
284 (
285 KnownExpr::Integer(Integer { value: a }),
286 KnownExpr::Integer(Integer { value: b }),
287 ) => return Integer::new_box(a + b),
288 (KnownExpr::Rational(r1), KnownExpr::Rational(r2)) => return Box::new(r1 + r2),
289 (KnownExpr::Add(Add { operands: ops_a }), KnownExpr::Add(Add { operands: ops_b })) => {
290 ops_a
291 .iter()
292 .chain(ops_b.iter())
293 .filter(|x| !x.is_zero())
294 .for_each(|op| {
295 let (coeff, expr) = op.get_coeff();
296 let entry = term_coeffs.entry(expr.clone()).or_insert(Rational::zero());
297 *entry += coeff;
298 });
299 }
300 (KnownExpr::Add(Add { operands }), _) => {
301 operands
302 .iter()
303 .map(|e| e.get_ref())
304 .chain(iter::once(rhs))
305 .filter(|x| !x.is_zero())
306 .for_each(|op| {
307 let (coeff, expr) = op.get_coeff();
308 let entry = term_coeffs.entry(expr.clone()).or_insert(Rational::zero());
309 *entry += coeff;
310 });
311 }
312 (_, KnownExpr::Add(Add { operands })) => {
313 iter::once(self)
314 .chain(operands.iter().map(|e| e.get_ref()))
315 .filter(|x| !x.is_zero())
316 .for_each(|op| {
317 let (coeff, expr) = op.get_coeff();
318 let entry = term_coeffs.entry(expr.clone()).or_insert(Rational::zero());
319 *entry += coeff;
320 });
321 }
322 _ => {
323 iter::once(self)
324 .chain(iter::once(rhs))
325 .filter(|x| !x.is_zero())
326 .for_each(|op| {
327 let (coeff, expr) = op.get_coeff();
328 let entry = term_coeffs.entry(expr.clone()).or_insert(Rational::zero());
329 *entry += coeff;
330 });
331 }
332 };
333
334 let mut operands: Vec<Box<dyn Expr>> = Vec::with_capacity(term_coeffs.len());
335
336 for (expr, coeff) in term_coeffs {
337 if coeff.is_zero() {
338 continue;
339 }
340
341 if !coeff.is_one() {
342 operands.push(coeff.simplify() * expr);
343 } else {
344 operands.push(expr)
345 }
346 }
347
348 if operands.len() == 0 {
349 Integer::new_box(0)
350 } else if operands.len() == 1 {
351 operands[0].clone_box()
352 } else {
353 Box::new(Add { operands })
354 }
355 }
356}
357impl std::ops::Add for &Box<dyn Expr> {
358 type Output = Box<dyn Expr>;
359
360 fn add(self, rhs: &Box<dyn Expr>) -> Self::Output {
361 &**self + &**rhs
362 }
363}
364
365impl std::ops::Add<Box<dyn Expr>> for &dyn Expr {
366 type Output = Box<dyn Expr>;
367
368 fn add(self, rhs: Box<dyn Expr>) -> Self::Output {
369 self + &*rhs
370 }
371}
372
373impl std::ops::Add for Box<dyn Expr> {
374 type Output = Box<dyn Expr>;
375
376 fn add(self, rhs: Box<dyn Expr>) -> Self::Output {
377 &*self + &*rhs
378 }
379}
380
381impl std::ops::AddAssign for Box<dyn Expr> {
382 fn add_assign(&mut self, rhs: Self) {
383 *self = self.get_ref() + rhs.get_ref();
384 }
385}
386impl<'a> From<&'a Add> for &'a dyn Expr {
387 fn from(value: &'a Add) -> Self {
388 value as &'a dyn Expr
389 }
390}
391
392impl std::ops::Mul<&dyn Expr> for Add {
393 type Output = Mul;
394
395 fn mul(self, rhs: &dyn Expr) -> Self::Output {
396 Mul::new([&self as &dyn Expr, rhs])
397 }
398}
399
400impl std::ops::Sub for &dyn Expr {
401 type Output = Box<dyn Expr>;
402
403 fn sub(self, rhs: Self) -> Self::Output {
404 self + &*(-rhs)
405 }
406}
407impl std::ops::Sub for &Box<dyn Expr> {
408 type Output = Box<dyn Expr>;
409
410 fn sub(self, rhs: &Box<dyn Expr>) -> Self::Output {
411 &**self - &**rhs
412 }
413}
414
415impl std::ops::Sub<Box<dyn Expr>> for &Box<dyn Expr> {
416 type Output = Box<dyn Expr>;
417
418 fn sub(self, rhs: Box<dyn Expr>) -> Self::Output {
419 &**self - &*rhs
420 }
421}
422
423impl std::ops::Sub<&Box<dyn Expr>> for Box<dyn Expr> {
424 type Output = Box<dyn Expr>;
425
426 fn sub(self, rhs: &Box<dyn Expr>) -> Self::Output {
427 &*self - &**rhs
428 }
429}
430
431impl std::ops::Sub for Box<dyn Expr> {
432 type Output = Box<dyn Expr>;
433
434 fn sub(self, rhs: Box<dyn Expr>) -> Self::Output {
435 &*self - &*rhs
436 }
437}
438
439impl std::ops::SubAssign<&dyn Expr> for Box<dyn Expr> {
440 fn sub_assign(&mut self, rhs: &dyn Expr) {
441 *self = &**self - rhs;
442 }
443}
444
445impl std::ops::Add<&dyn Expr> for Box<dyn Expr> {
446 type Output = Box<dyn Expr>;
447
448 fn add(self, rhs: &dyn Expr) -> Self::Output {
449 &*self + rhs
450 }
451}
452
453impl std::ops::Add<isize> for Box<dyn Expr> {
454 type Output = Box<dyn Expr>;
455
456 fn add(self, rhs: isize) -> Self::Output {
457 &*self + Integer::new_box(rhs)
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` into an expression; test inputs are expected to be valid.
    fn parse(src: &str) -> Box<dyn Expr> {
        src.parse().unwrap()
    }

    #[test]
    fn test_srepr() {
        let (a, b) = (Symbol::new_box("a"), Symbol::new_box("b"));

        assert_eq!(
            (a - b).srepr(),
            "Add(Symbol(a), Mul(Integer(-1), Symbol(b)))"
        );
    }

    #[test]
    fn test_srepr_2() {
        let (a, b, c) = (
            Symbol::new_box("a"),
            Symbol::new_box("b"),
            Symbol::new_box("c"),
        );

        assert_eq!(
            (a - b * c).srepr(),
            "Add(Symbol(a), Mul(Integer(-1), Symbol(b), Symbol(c)))"
        );
    }

    #[test]
    fn test_simplify_dimension() {
        let expr = parse("d2u/dx2 + d2u/dy2");

        assert_eq!(expr.simplify_with_dimension(2), parse("laplacian * u"));
    }

    #[test]
    #[ignore]
    fn test_simplify_dim_advanced_add() {
        let expr = parse("c^2 * (∂^2u / ∂x^2 + ∂^2u / ∂y^2) + source");

        assert_eq!(
            expr.simplify_with_dimension(2),
            parse("c^2 * laplacian * u + source")
        );
    }
}