1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
use crate::TensorExpression;
use std::ops::Add;
/// Symbolic addition with light simplification.
///
/// The rules are tried in this order:
/// * `Constant(a) + Constant(b)` is folded into a single `Constant`.
/// * `Zero + x` yields `x`.
/// * `x + Zero` yields `x`.
/// * Anything else becomes an unevaluated `TensorExpression::Add` node.
///
/// Because of that order, `Zero + Zero` yields `Zero`.
impl Add for TensorExpression {
    type Output = Self;

    fn add(self, rhs: TensorExpression) -> Self::Output {
        match (self, rhs) {
            // Both operands are literals: evaluate eagerly (left operand by
            // reference, right by value, exactly as before).
            (TensorExpression::Constant(lhs_value), TensorExpression::Constant(rhs_value)) => {
                TensorExpression::Constant(&lhs_value + rhs_value)
            }
            // Zero is the additive identity on either side; drop it.
            (TensorExpression::Zero, other) => other,
            (other, TensorExpression::Zero) => other,
            // No simplification applies: keep the sum symbolic.
            (lhs, rhs) => TensorExpression::Add(lhs.into(), rhs.into()),
        }
    }
}
impl TensorExpression {
    /// Differentiates the sum `l + r` with respect to each name in `symbols`.
    ///
    /// The derivative of a sum is the sum of the derivatives. This method
    /// computes `l.differential(symbols)` and then `r.differential(symbols)`,
    /// and adds the results pairwise with `+`.
    ///
    /// The returned vector has one entry per pair yielded by `zip`. If the two
    /// differentials ever differ in length, the surplus entries of the longer
    /// one are silently dropped. Presumably both always have one entry per
    /// symbol; confirm against `differential`.
    ///
    /// Takes `&TensorExpression`, so callers holding `&Box<TensorExpression>`
    /// still work via deref coercion.
    pub(crate) fn diff_add(
        symbols: &[&str],
        l: &TensorExpression,
        r: &TensorExpression,
    ) -> Vec<TensorExpression> {
        let left_partials = l.differential(symbols);
        let right_partials = r.differential(symbols);
        left_partials
            .into_iter()
            .zip(right_partials)
            .map(|(li, ri)| li + ri)
            .collect()
    }

    /// Renders the sum `l + r` as Rust source text of the form `"{l} + {r}"`.
    ///
    /// When `parentheses` is true, the whole sum is wrapped as `"({l} + {r})"`.
    /// Both operands are always rendered with `_rust_code(true)`. This
    /// presumably asks each operand to parenthesize itself as well, which is
    /// conservative but always safe; verify against `_rust_code`.
    pub(crate) fn rust_code_add(
        l: &TensorExpression,
        r: &TensorExpression,
        parentheses: bool,
    ) -> String {
        let sum = format!("{} + {}", l._rust_code(true), r._rust_code(true));
        if parentheses {
            format!("({})", sum)
        } else {
            sum
        }
    }
}