// opensrdk_symbolic_computation/expression/operators/mul.rs

1use crate::{BracketsLevel, ConstantValue, Expression, ExpressionArray, TranscendentalExpression};
2use std::{collections::HashMap, ops::Mul};
3
4impl Mul<Expression> for Expression {
5    type Output = Self;
6
7    fn mul(self, rhs: Expression) -> Self::Output {
8        if !self.is_same_size(&rhs) {
9            panic!("Cannot add expressions of different sizes");
10        }
11        // Merge constant
12
13        if let (Expression::PartialVariable(vl), Expression::PartialVariable(vr)) = (&self, &rhs) {
14            // if vl.sizes() == vr.sizes() {
15            //     panic!("Mistach Sizes of Variables");
16            // }
17
18            return Expression::PartialVariable(ExpressionArray::from_factory(
19                vr.sizes().to_vec(),
20                |indices| vl[indices].clone().mul(vr[indices].clone()),
21            ));
22        }
23
24        if let Expression::PartialVariable(vr) = &rhs {
25            return Expression::PartialVariable(ExpressionArray::from_factory(
26                vr.sizes().to_vec(),
27                |indices| self.clone().mul(vr[indices].clone()),
28            ));
29        }
30
31        if let Expression::PartialVariable(vl) = &self {
32            return Expression::PartialVariable(ExpressionArray::from_factory(
33                vl.sizes().to_vec(),
34                |indices| vl[indices].clone().mul(rhs.clone()),
35            ));
36        }
37
38        if let Expression::Constant(vl) = &self {
39            if vl == &ConstantValue::Scalar(0.0) {
40                return 0.0.into();
41            }
42            if vl == &ConstantValue::Scalar(1.0) {
43                return rhs;
44            }
45            if let Expression::Constant(vr) = rhs {
46                return vl.mul(vr).into();
47            }
48        }
49        if let Expression::Constant(vr) = &rhs {
50            if vr == &ConstantValue::Scalar(0.0) {
51                return 0.0.into();
52            }
53            if vr == &ConstantValue::Scalar(1.0) {
54                return self;
55            }
56        }
57        // Merge pow
58        if let Expression::Transcendental(vl) = &self {
59            if let TranscendentalExpression::Pow(vl, el) = vl.as_ref() {
60                if let Expression::Transcendental(vr) = &rhs {
61                    if let TranscendentalExpression::Pow(vr, er) = vr.as_ref() {
62                        if vl.as_ref() == vr.as_ref() {
63                            return vl.clone().pow(*el.clone() + *er.clone());
64                        }
65                    }
66                }
67                if vl.as_ref() == &rhs {
68                    let one: Expression = 1.0.into();
69                    return vl.clone().pow(*el.clone() + one);
70                }
71            }
72        }
73        if let Expression::Transcendental(vr) = &rhs {
74            if let TranscendentalExpression::Pow(vr, er) = vr.as_ref() {
75                if vr.as_ref() == &self {
76                    let one: Expression = 1.0.into();
77                    return vr.clone().pow(*er.clone() + one);
78                }
79            }
80        }
81
82        Expression::Mul(self.into(), rhs.into())
83    }
84}
85
86impl Mul<Expression> for f64 {
87    type Output = Expression;
88
89    fn mul(self, rhs: Expression) -> Self::Output {
90        Expression::Constant(ConstantValue::Scalar(self)) * rhs
91    }
92}
93
94impl Mul<f64> for Expression {
95    type Output = Self;
96
97    fn mul(self, rhs: f64) -> Self::Output {
98        self * Expression::Constant(ConstantValue::Scalar(rhs))
99    }
100}
101
102impl Expression {
103    pub(crate) fn diff_mul(
104        l: &Box<Expression>,
105        r: &Box<Expression>,
106        variable_ids: &[&str],
107    ) -> Vec<Expression> {
108        l.differential(variable_ids)
109            .into_iter()
110            .zip(r.differential(variable_ids).into_iter())
111            .map(|(li, ri)| li * r.as_ref().clone() + l.as_ref().clone() * ri)
112            .collect()
113    }
114
115    pub(crate) fn tex_code_mul(
116        l: &Box<Expression>,
117        r: &Box<Expression>,
118        symbols: &HashMap<&str, &str>,
119        brackets_level: BracketsLevel,
120    ) -> String {
121        let inner = format!(
122            r"{{{} \times {}}}",
123            l._tex_code(symbols, BracketsLevel::ForMul),
124            r._tex_code(symbols, BracketsLevel::ForMul)
125        );
126
127        match brackets_level {
128            BracketsLevel::None | BracketsLevel::ForMul => inner,
129            BracketsLevel::ForDiv | BracketsLevel::ForOperation => {
130                format!(r"\left({}\right)", inner)
131            }
132        }
133    }
134}
135
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::Expression;

    /// Multiplying constant expressions must give the same expression as
    /// building one directly from the product of the underlying values.
    /// Three kinds of operand are checked: a scalar, a dense vector and a
    /// sparse tensor.
    #[test]
    fn it_works() {
        let a1 = 5.0f64;
        let b1 = vec![a1; 8];
        let mut hash1 = HashMap::new();
        hash1.insert(vec![3, 2, 1], 2.0);
        hash1.insert(vec![1usize; 3], 3.0);
        hash1.insert(vec![4usize; 3], 4.0);
        hash1.insert(vec![5usize; 3], 2.0);
        let c1 = SparseTensor::from(vec![6usize; 3], hash1).unwrap();

        let ea1 = Expression::from(a1);
        let eb1 = Expression::from(b1.clone());
        let ec1 = Expression::from(c1.clone());

        let a2 = 5.0f64;
        let b2 = vec![a2; 8];
        let mut hash2 = HashMap::new();
        hash2.insert(vec![3usize; 3], 2.0);
        hash2.insert(vec![1usize; 3], 3.0);
        hash2.insert(vec![2, 1, 1], 4.0);
        hash2.insert(vec![5usize; 3], 2.0);
        let c2 = SparseTensor::from(vec![6usize; 3], hash2).unwrap();

        let ea2 = Expression::from(a2);
        let eb2 = Expression::from(b2.clone());
        let ec2 = Expression::from(c2.clone());

        let ea = ea1 * ea2;
        let eb = eb1 * eb2;
        let ec = ec1 * ec2;

        let a = Expression::from(a1 * a2);
        // Element-wise product of the two dense vectors.
        let b = Expression::from(
            b1.iter()
                .zip(&b2)
                .map(|(x, y)| x * y)
                .collect::<Vec<f64>>(),
        );
        let c = Expression::from(c1 * c2);

        assert_eq!(ea, a);
        assert_eq!(eb, b);
        assert_eq!(ec, c);
    }
}