opensrdk_symbolic_computation/expression/operators/
mul.rs1use crate::{BracketsLevel, ConstantValue, Expression, ExpressionArray, TranscendentalExpression};
2use std::{collections::HashMap, ops::Mul};
3
4impl Mul<Expression> for Expression {
5 type Output = Self;
6
7 fn mul(self, rhs: Expression) -> Self::Output {
8 if !self.is_same_size(&rhs) {
9 panic!("Cannot add expressions of different sizes");
10 }
11 if let (Expression::PartialVariable(vl), Expression::PartialVariable(vr)) = (&self, &rhs) {
14 return Expression::PartialVariable(ExpressionArray::from_factory(
19 vr.sizes().to_vec(),
20 |indices| vl[indices].clone().mul(vr[indices].clone()),
21 ));
22 }
23
24 if let Expression::PartialVariable(vr) = &rhs {
25 return Expression::PartialVariable(ExpressionArray::from_factory(
26 vr.sizes().to_vec(),
27 |indices| self.clone().mul(vr[indices].clone()),
28 ));
29 }
30
31 if let Expression::PartialVariable(vl) = &self {
32 return Expression::PartialVariable(ExpressionArray::from_factory(
33 vl.sizes().to_vec(),
34 |indices| vl[indices].clone().mul(rhs.clone()),
35 ));
36 }
37
38 if let Expression::Constant(vl) = &self {
39 if vl == &ConstantValue::Scalar(0.0) {
40 return 0.0.into();
41 }
42 if vl == &ConstantValue::Scalar(1.0) {
43 return rhs;
44 }
45 if let Expression::Constant(vr) = rhs {
46 return vl.mul(vr).into();
47 }
48 }
49 if let Expression::Constant(vr) = &rhs {
50 if vr == &ConstantValue::Scalar(0.0) {
51 return 0.0.into();
52 }
53 if vr == &ConstantValue::Scalar(1.0) {
54 return self;
55 }
56 }
57 if let Expression::Transcendental(vl) = &self {
59 if let TranscendentalExpression::Pow(vl, el) = vl.as_ref() {
60 if let Expression::Transcendental(vr) = &rhs {
61 if let TranscendentalExpression::Pow(vr, er) = vr.as_ref() {
62 if vl.as_ref() == vr.as_ref() {
63 return vl.clone().pow(*el.clone() + *er.clone());
64 }
65 }
66 }
67 if vl.as_ref() == &rhs {
68 let one: Expression = 1.0.into();
69 return vl.clone().pow(*el.clone() + one);
70 }
71 }
72 }
73 if let Expression::Transcendental(vr) = &rhs {
74 if let TranscendentalExpression::Pow(vr, er) = vr.as_ref() {
75 if vr.as_ref() == &self {
76 let one: Expression = 1.0.into();
77 return vr.clone().pow(*er.clone() + one);
78 }
79 }
80 }
81
82 Expression::Mul(self.into(), rhs.into())
83 }
84}
85
86impl Mul<Expression> for f64 {
87 type Output = Expression;
88
89 fn mul(self, rhs: Expression) -> Self::Output {
90 Expression::Constant(ConstantValue::Scalar(self)) * rhs
91 }
92}
93
94impl Mul<f64> for Expression {
95 type Output = Self;
96
97 fn mul(self, rhs: f64) -> Self::Output {
98 self * Expression::Constant(ConstantValue::Scalar(rhs))
99 }
100}
101
102impl Expression {
103 pub(crate) fn diff_mul(
104 l: &Box<Expression>,
105 r: &Box<Expression>,
106 variable_ids: &[&str],
107 ) -> Vec<Expression> {
108 l.differential(variable_ids)
109 .into_iter()
110 .zip(r.differential(variable_ids).into_iter())
111 .map(|(li, ri)| li * r.as_ref().clone() + l.as_ref().clone() * ri)
112 .collect()
113 }
114
115 pub(crate) fn tex_code_mul(
116 l: &Box<Expression>,
117 r: &Box<Expression>,
118 symbols: &HashMap<&str, &str>,
119 brackets_level: BracketsLevel,
120 ) -> String {
121 let inner = format!(
122 r"{{{} \times {}}}",
123 l._tex_code(symbols, BracketsLevel::ForMul),
124 r._tex_code(symbols, BracketsLevel::ForMul)
125 );
126
127 match brackets_level {
128 BracketsLevel::None | BracketsLevel::ForMul => inner,
129 BracketsLevel::ForDiv | BracketsLevel::ForOperation => {
130 format!(r"\left({}\right)", inner)
131 }
132 }
133 }
134}
135
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, ops::Add};

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::Expression;

    /// Multiplying expressions built from a scalar, a vector and a sparse
    /// tensor must equal the expression built from the directly computed
    /// product of the underlying values.
    #[test]
    fn it_works() {
        // Left-hand operands.
        let scalar_l = 5.0f64;
        let vector_l = vec![scalar_l; 8];
        let entries_l: HashMap<_, _> = vec![
            (vec![3, 2, 1], 2.0),
            (vec![1usize; 3], 3.0),
            (vec![4usize; 3], 4.0),
            (vec![5usize; 3], 2.0),
        ]
        .into_iter()
        .collect();
        let tensor_l = SparseTensor::from(vec![6usize; 3], entries_l).unwrap();

        // Right-hand operands.
        let scalar_r = 5.0f64;
        let vector_r = vec![scalar_r; 8];
        let entries_r: HashMap<_, _> = vec![
            (vec![3usize; 3], 2.0),
            (vec![1usize; 3], 3.0),
            (vec![2, 1, 1], 4.0),
            (vec![5usize; 3], 2.0),
        ]
        .into_iter()
        .collect();
        let tensor_r = SparseTensor::from(vec![6usize; 3], entries_r).unwrap();

        // Products computed through the `Expression` operator.
        let scalar_product = Expression::from(scalar_l) * Expression::from(scalar_r);
        let vector_product =
            Expression::from(vector_l.clone()) * Expression::from(vector_r.clone());
        let tensor_product =
            Expression::from(tensor_l.clone()) * Expression::from(tensor_r.clone());

        // Reference products computed directly on the raw values.
        let expected_scalar = Expression::from(scalar_l * scalar_r);
        let expected_vector = Expression::from(
            vector_l
                .iter()
                .zip(vector_r.iter())
                .map(|(x, y)| x * y)
                .collect::<Vec<f64>>(),
        );
        let expected_tensor = Expression::from(tensor_l * tensor_r);

        assert_eq!(scalar_product, expected_scalar);
        assert_eq!(vector_product, expected_vector);
        assert_eq!(tensor_product, expected_tensor);
    }
}