opensrdk_symbolic_computation/expression/operators/
add.rs1use crate::{BracketsLevel, ConstantValue, Expression, ExpressionArray};
2use std::{collections::HashMap, ops::Add};
3
4impl Add<Expression> for Expression {
5 type Output = Self;
6
7 fn add(self, rhs: Expression) -> Self::Output {
8 if !self.is_same_size(&rhs) {
9 panic!("Cannot add expressions of different sizes");
10 }
11
12 if let (Expression::PartialVariable(vl), Expression::PartialVariable(vr)) = (&self, &rhs) {
13 return Expression::PartialVariable(ExpressionArray::from_factory(
18 vr.sizes().to_vec(),
19 |indices| vl[indices].clone().add(vr[indices].clone()),
20 ));
21 }
22
23 if let Expression::Constant(vl) = &self {
38 if vl == &ConstantValue::Scalar(0.0) {
39 return rhs;
40 }
41 if let Expression::Constant(vr) = rhs {
42 return vl.add(vr).into();
43 }
44 }
45 if let Expression::Constant(vr) = &rhs {
46 if vr == &ConstantValue::Scalar(0.0) {
47 return self;
48 }
49 }
50
51 Expression::Add(self.into(), rhs.into())
52 }
53}
54
55impl Add<f64> for Expression {
56 type Output = Self;
57
58 fn add(self, rhs: f64) -> Self::Output {
59 self + Expression::Constant(ConstantValue::Scalar(rhs))
60 }
61}
62
63impl Add<Expression> for f64 {
64 type Output = Expression;
65
66 fn add(self, rhs: Expression) -> Self::Output {
67 Expression::Constant(ConstantValue::Scalar(self)) + rhs
68 }
69}
70
71impl Expression {
72 pub(crate) fn diff_add(
73 l: &Box<Expression>,
74 r: &Box<Expression>,
75 variable_ids: &[&str],
76 ) -> Vec<Expression> {
77 l.differential(variable_ids)
78 .into_iter()
79 .zip(r.differential(variable_ids).into_iter())
80 .map(|(li, ri)| li + ri)
81 .collect()
82 }
83
84 pub(crate) fn tex_code_add(
85 l: &Box<Expression>,
86 r: &Box<Expression>,
87 symbols: &HashMap<&str, &str>,
88 brackets_level: BracketsLevel,
89 ) -> String {
90 let inner = format!(
91 "{{{} + {}}}",
92 l._tex_code(symbols, BracketsLevel::None),
93 r._tex_code(symbols, BracketsLevel::None)
94 );
95
96 match brackets_level {
97 BracketsLevel::None => inner,
98 BracketsLevel::ForMul | BracketsLevel::ForDiv | BracketsLevel::ForOperation => {
99 format!(r"\left({}\right)", inner)
100 }
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    // `ops::Add` was imported here but never used: the test only uses `+`
    // operator syntax, which needs no trait import.
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::Expression;

    /// Checks that adding scalar, dense-vector and sparse-tensor constant
    /// expressions gives the same expression as adding the raw values first
    /// and then converting the sum with `Expression::from`.
    #[test]
    fn it_works() {
        let a1 = 5.0f64;
        let b1 = vec![a1; 8];
        let mut hash1 = HashMap::new();
        hash1.insert(vec![3, 2, 1], 2.0);
        hash1.insert(vec![1usize; 3], 3.0);
        hash1.insert(vec![4usize; 3], 4.0);
        hash1.insert(vec![5usize; 3], 2.0);
        let c1 = SparseTensor::from(vec![6usize; 3], hash1).unwrap();

        let ea1 = Expression::from(a1);
        let eb1 = Expression::from(b1.clone());
        let ec1 = Expression::from(c1.clone());

        let a2 = 5.0f64;
        let b2 = vec![a2; 8];
        let mut hash2 = HashMap::new();
        hash2.insert(vec![3usize; 3], 2.0);
        hash2.insert(vec![1usize; 3], 3.0);
        hash2.insert(vec![2, 1, 1], 4.0);
        hash2.insert(vec![5usize; 3], 2.0);
        let c2 = SparseTensor::from(vec![6usize; 3], hash2).unwrap();

        let ea2 = Expression::from(a2);
        let eb2 = Expression::from(b2.clone());
        let ec2 = Expression::from(c2.clone());

        let ea = ea1 + ea2;
        let eb = eb1 + eb2;
        let ec = ec1 + ec2;

        let a = Expression::from(a1 + a2);
        // Element-wise sum of the two dense vectors, pairing entries with `zip`
        // instead of indexing `b2` by position.
        let b = Expression::from(
            b1.iter()
                .zip(&b2)
                .map(|(x1, x2)| x2 + x1)
                .collect::<Vec<f64>>(),
        );
        let c = Expression::from(c1 + c2);

        assert_eq!(ea, a);
        assert_eq!(eb, b);
        assert_eq!(ec, c);
    }
}