opensrdk_symbolic_computation/expression/operators/sub.rs

1use crate::{BracketsLevel, ConstantValue, Expression, ExpressionArray};
2use std::{collections::HashMap, ops::Sub};
3
4impl Sub<Expression> for Expression {
5    type Output = Self;
6
7    fn sub(self, rhs: Expression) -> Self::Output {
8        if !self.is_same_size(&rhs) {
9            panic!("Cannot add expressions of different sizes");
10        }
11
12        if let (Expression::PartialVariable(vl), Expression::PartialVariable(vr)) = (&self, &rhs) {
13            // if vl.sizes() == vr.sizes() {
14            //     panic!("Mistach Sizes of Variables");
15            // }
16
17            return Expression::PartialVariable(ExpressionArray::from_factory(
18                vr.sizes().to_vec(),
19                |indices| vl[indices].clone().sub(vr[indices].clone()),
20            ));
21        }
22
23        // if let Expression::PartialVariable(vr) = &rhs {
24        //     return Expression::PartialVariable(ExpressionArray::from_factory(
25        //         vr.sizes().to_vec(),
26        //         |indices| self.clone().sub(vr[indices].clone()),
27        //     ));
28        // }
29
30        // if let Expression::PartialVariable(vl) = &self {
31        //     return Expression::PartialVariable(ExpressionArray::from_factory(
32        //         vl.sizes().to_vec(),
33        //         |indices| vl[indices].clone().sub(rhs.clone()),
34        //     ));
35        // }
36
37        if let Expression::Constant(vl) = &self {
38            if vl == &ConstantValue::Scalar(0.0) {
39                return rhs;
40            }
41            if let Expression::Constant(vr) = rhs {
42                return vl.sub(vr).into();
43            }
44        }
45        if let Expression::Constant(vr) = &rhs {
46            if vr == &ConstantValue::Scalar(0.0) {
47                return self;
48            }
49        }
50
51        Self::Sub(self.into(), rhs.into())
52    }
53}
54
55impl Sub<f64> for Expression {
56    type Output = Self;
57
58    fn sub(self, rhs: f64) -> Self::Output {
59        self - Expression::Constant(ConstantValue::Scalar(rhs))
60    }
61}
62
63impl Sub<Expression> for f64 {
64    type Output = Expression;
65
66    fn sub(self, rhs: Expression) -> Self::Output {
67        Expression::Constant(ConstantValue::Scalar(self)) - rhs
68    }
69}
70
71impl Expression {
72    pub(crate) fn diff_sub(
73        l: &Box<Expression>,
74        r: &Box<Expression>,
75        variable_ids: &[&str],
76    ) -> Vec<Expression> {
77        l.differential(variable_ids)
78            .into_iter()
79            .zip(r.differential(variable_ids).into_iter())
80            .map(|(li, ri)| li - ri)
81            .collect()
82    }
83
84    pub(crate) fn tex_code_sub(
85        l: &Box<Expression>,
86        r: &Box<Expression>,
87        symbols: &HashMap<&str, &str>,
88        brackets_level: BracketsLevel,
89    ) -> String {
90        let inner = format!(
91            "{{{} - {}}}",
92            l._tex_code(symbols, BracketsLevel::None),
93            r._tex_code(symbols, BracketsLevel::None)
94        );
95
96        match brackets_level {
97            BracketsLevel::None => inner,
98            BracketsLevel::ForMul | BracketsLevel::ForDiv | BracketsLevel::ForOperation => {
99                format!(r"\left({}\right)", inner)
100            }
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::Expression;

    /// Subtracting scalar, vector and sparse-tensor expressions must give the same
    /// result as subtracting the raw values first and converting afterwards.
    #[test]
    fn it_works() {
        let a1 = 5.0f64;
        let b1 = vec![a1; 8];
        let mut hash1 = HashMap::new();
        hash1.insert(vec![3, 2, 1], 2.0);
        hash1.insert(vec![1usize; 3], 3.0);
        hash1.insert(vec![4usize; 3], 4.0);
        hash1.insert(vec![5usize; 3], 2.0);
        let c1 = SparseTensor::from(vec![6usize; 3], hash1).unwrap();

        let ea1 = Expression::from(a1);
        let eb1 = Expression::from(b1.clone());
        let ec1 = Expression::from(c1.clone());

        let a2 = 5.0f64;
        let b2 = vec![a2; 8];
        let mut hash2 = HashMap::new();
        hash2.insert(vec![3usize; 3], 2.0);
        hash2.insert(vec![1usize; 3], 3.0);
        hash2.insert(vec![2, 1, 1], 4.0);
        hash2.insert(vec![5usize; 3], 2.0);
        let c2 = SparseTensor::from(vec![6usize; 3], hash2).unwrap();

        let ea2 = Expression::from(a2);
        let eb2 = Expression::from(b2.clone());
        let ec2 = Expression::from(c2.clone());

        let ea = ea1 - ea2;
        let eb = eb1 - eb2;
        let ec = ec1 - ec2;

        // Expected values: subtract the raw data first, then wrap in an `Expression`.
        let a = Expression::from(a1 - a2);
        let b = Expression::from(
            b1.iter()
                .zip(&b2)
                .map(|(x, y)| x - y)
                .collect::<Vec<f64>>(),
        );
        let c = Expression::from(c1 - c2);

        assert_eq!(ea, a);
        assert_eq!(eb, b);
        assert_eq!(ec, c);
    }
}