opensrdk_symbolic_computation/expression/operators/div.rs

1use crate::{BracketsLevel, ConstantValue, Expression, ExpressionArray};
2use std::{collections::HashMap, ops::Div};
3
4impl Div<Expression> for Expression {
5    type Output = Self;
6
7    fn div(self, rhs: Expression) -> Self::Output {
8        if !self.is_same_size(&rhs) {
9            panic!("Cannot add expressions of different sizes");
10        }
11
12        if let (Expression::PartialVariable(vl), Expression::PartialVariable(vr)) = (&self, &rhs) {
13            // if vl.sizes() == vr.sizes() {
14            //     panic!("Mistach Sizes of Variables");
15            // }
16
17            return Expression::PartialVariable(ExpressionArray::from_factory(
18                vr.sizes().to_vec(),
19                |indices| vl[indices].clone().div(vr[indices].clone()),
20            ));
21        }
22
23        // if let Expression::PartialVariable(vr) = &rhs {
24        //     return Expression::PartialVariable(ExpressionArray::from_factory(
25        //         vr.sizes().to_vec(),
26        //         |indices| self.clone().div(vr[indices].clone()),
27        //     ));
28        // }
29
30        // if let Expression::PartialVariable(vl) = &self {
31        //     return Expression::PartialVariable(ExpressionArray::from_factory(
32        //         vl.sizes().to_vec(),
33        //         |indices| vl[indices].clone().div(rhs.clone()),
34        //     ));
35        // }
36
37        if let Expression::Constant(vr) = &rhs {
38            if vr == &ConstantValue::Scalar(1.0) {
39                return self;
40            }
41            if let Expression::Constant(vl) = self {
42                return vl.div(vr).into();
43            }
44        }
45
46        Expression::Div(self.into(), rhs.into())
47    }
48}
49
50impl Div<f64> for Expression {
51    type Output = Self;
52
53    fn div(self, rhs: f64) -> Self::Output {
54        self / Expression::Constant(ConstantValue::Scalar(rhs))
55    }
56}
57
58impl Div<Expression> for f64 {
59    type Output = Expression;
60
61    fn div(self, rhs: Expression) -> Self::Output {
62        Expression::Constant(ConstantValue::Scalar(self)) / rhs
63    }
64}
65
66impl Expression {
67    pub(crate) fn diff_div(
68        l: &Box<Expression>,
69        r: &Box<Expression>,
70        variable_ids: &[&str],
71    ) -> Vec<Expression> {
72        l.differential(variable_ids)
73            .into_iter()
74            .zip(r.differential(variable_ids).into_iter())
75            .map(|(li, ri)| {
76                (li * r.as_ref().clone()
77                    - l.as_ref().clone() * ri / r.as_ref().clone().pow(2.0.into()))
78            })
79            .collect()
80    }
81
82    pub(crate) fn tex_code_div(
83        l: &Box<Expression>,
84        r: &Box<Expression>,
85        symbols: &HashMap<&str, &str>,
86        brackets_level: BracketsLevel,
87    ) -> String {
88        let inner = format!(
89            "{{{} / {}}}",
90            l._tex_code(symbols, BracketsLevel::ForDiv),
91            r._tex_code(symbols, BracketsLevel::ForDiv)
92        );
93
94        match brackets_level {
95            BracketsLevel::None | BracketsLevel::ForMul => inner,
96            BracketsLevel::ForDiv | BracketsLevel::ForOperation => {
97                format!(r"\left({}\right)", inner)
98            }
99        }
100    }
101}
102
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use opensrdk_linear_algebra::sparse::SparseTensor;

    use crate::Expression;

    /// Dividing two scalar constant expressions folds to the scalar quotient.
    #[test]
    fn div_scalar_constants() {
        let a1 = 5.0f64;
        let a2 = 5.0f64;

        let ea = Expression::from(a1) / Expression::from(a2);

        assert_eq!(ea, Expression::from(a1 / a2));
    }

    /// Dividing two vector constant expressions folds to the element-wise quotient.
    #[test]
    fn div_vector_constants() {
        let b1 = vec![5.0f64; 8];
        let b2 = vec![5.0f64; 8];
        let expected = b1.iter().zip(&b2).map(|(x, y)| x / y).collect::<Vec<f64>>();

        let eb = Expression::from(b1) / Expression::from(b2);

        assert_eq!(eb, Expression::from(expected));
    }

    /// Dividing two sparse-tensor constant expressions matches `SparseTensor` division.
    #[test]
    fn div_sparse_tensor_constants() {
        let mut hash1 = HashMap::new();
        hash1.insert(vec![3usize, 2, 1], 2.0);
        hash1.insert(vec![1usize; 3], 3.0);
        hash1.insert(vec![4usize; 3], 4.0);
        hash1.insert(vec![5usize; 3], 2.0);
        let c1 = SparseTensor::from(vec![6usize; 3], hash1).unwrap();

        let mut hash2 = HashMap::new();
        hash2.insert(vec![3usize; 3], 2.0);
        hash2.insert(vec![1usize; 3], 3.0);
        hash2.insert(vec![2usize, 1, 1], 4.0);
        hash2.insert(vec![5usize; 3], 2.0);
        let c2 = SparseTensor::from(vec![6usize; 3], hash2).unwrap();

        let ec = Expression::from(c1.clone()) / Expression::from(c2.clone());

        assert_eq!(ec, Expression::from(c1 / c2));
    }
}