opensrdk_linear_algebra/tensor/sparse/operators/div.rs

1use crate::{
2    indices_cartesian_product,
3    number::{c64, Number},
4    sparse::SparseTensor,
5};
6use rayon::prelude::*;
7use std::ops::{Div, DivAssign};
8
9fn div_scalar<T>(lhs: T, rhs: SparseTensor<T>) -> SparseTensor<T>
10where
11    T: Number,
12{
13    let mut rhs = rhs;
14
15    rhs.elems
16        .par_iter_mut()
17        .map(|r| {
18            *r.1 /= lhs;
19        })
20        .collect::<Vec<_>>();
21
22    rhs
23}
24
25fn div<T>(lhs: SparseTensor<T>, rhs: &SparseTensor<T>) -> SparseTensor<T>
26where
27    T: Number,
28{
29    if !lhs.is_same_size(rhs) {
30        panic!("Dimension mismatch.")
31    }
32    let mut lhs = lhs;
33
34    indices_cartesian_product(&lhs.sizes)
35        .into_iter()
36        .for_each(|k| {
37            if !lhs.elems.contains_key(&k) {
38                return;
39            }
40            lhs[&k] /= rhs[&k];
41        });
42
43    lhs
44}
45
46// Scalar and SparseTensor
47
48macro_rules! impl_div_scalar {
49  {$t: ty} => {
50      impl Div<SparseTensor<$t>> for $t {
51          type Output = SparseTensor<$t>;
52
53          fn div(self, rhs: SparseTensor<$t>) -> Self::Output {
54              div_scalar(self, rhs)
55          }
56      }
57
58      impl Div<SparseTensor<$t>> for &$t {
59          type Output = SparseTensor<$t>;
60
61          fn div(self, rhs: SparseTensor<$t>) -> Self::Output {
62              div_scalar(*self, rhs)
63          }
64      }
65  }
66}
67
68impl_div_scalar! {f64}
69impl_div_scalar! {c64}
70
71// SparseTensor and Scalar
72
73impl<T> Div<T> for SparseTensor<T>
74where
75    T: Number,
76{
77    type Output = SparseTensor<T>;
78
79    fn div(self, rhs: T) -> Self::Output {
80        div_scalar(rhs, self)
81    }
82}
83
84impl<T> Div<&T> for SparseTensor<T>
85where
86    T: Number,
87{
88    type Output = SparseTensor<T>;
89
90    fn div(self, rhs: &T) -> Self::Output {
91        div_scalar(*rhs, self)
92    }
93}
94
95// SparseTensor and SparseTensor
96
97impl<T> Div<SparseTensor<T>> for SparseTensor<T>
98where
99    T: Number,
100{
101    type Output = SparseTensor<T>;
102
103    fn div(self, rhs: SparseTensor<T>) -> Self::Output {
104        div(self, &rhs)
105    }
106}
107
108impl<T> Div<&SparseTensor<T>> for SparseTensor<T>
109where
110    T: Number,
111{
112    type Output = SparseTensor<T>;
113
114    fn div(self, rhs: &SparseTensor<T>) -> Self::Output {
115        div(self, rhs)
116    }
117}
118
119impl<T> Div<SparseTensor<T>> for &SparseTensor<T>
120where
121    T: Number,
122{
123    type Output = SparseTensor<T>;
124
125    fn div(self, rhs: SparseTensor<T>) -> Self::Output {
126        div(rhs, self)
127    }
128}
129
130// DivAssign
131
132impl<T> DivAssign<SparseTensor<T>> for SparseTensor<T>
133where
134    T: Number,
135{
136    fn div_assign(&mut self, rhs: SparseTensor<T>) {
137        *self = self as &Self / rhs;
138    }
139}
140
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// `tensor / scalar` divides every stored entry by the scalar.
    #[test]
    fn div_scalar() {
        let mut lhs = SparseTensor::new(vec![3, 2, 2]);
        lhs[&[0, 0, 0]] = 2.0;
        lhs[&[0, 0, 1]] = 4.0;
        lhs[&[1, 1, 0]] = 2.0;
        lhs[&[1, 1, 1]] = 4.0;
        lhs[&[2, 0, 0]] = 2.0;
        lhs[&[2, 0, 1]] = 4.0;

        let mut hash2 = HashMap::new();

        hash2.insert(vec![0usize, 0, 0], 1.0);
        hash2.insert(vec![0usize, 0, 1], 2.0);

        hash2.insert(vec![1usize, 1, 0], 1.0);
        hash2.insert(vec![1usize, 1, 1], 2.0);

        hash2.insert(vec![2usize, 0, 0], 1.0);
        hash2.insert(vec![2usize, 0, 1], 2.0);

        let rhs = SparseTensor::from(vec![3, 2, 2], hash2).unwrap();

        let res = lhs / 2.0;

        assert_eq!(res, rhs);
    }

    /// `tensor / &scalar` behaves like dividing by the scalar by value.
    #[test]
    fn div_scalar_by_ref() {
        let mut lhs = SparseTensor::new(vec![2, 2]);
        lhs[&[0, 1]] = 3.0;
        lhs[&[1, 0]] = 6.0;

        let divisor = 3.0_f64;
        let res = lhs / &divisor;

        assert_eq!(res[&[0, 1]], 1.0);
        assert_eq!(res[&[1, 0]], 2.0);
    }

    /// `tensor / tensor` divides stored entries element-wise.
    #[test]
    fn div() {
        let mut lhs = SparseTensor::new(vec![3, 2, 2]);
        lhs[&[0, 0, 0]] = 2.0;
        lhs[&[0, 0, 1]] = 4.0;
        lhs[&[1, 1, 0]] = 2.0;
        lhs[&[1, 1, 1]] = 4.0;
        lhs[&[2, 0, 0]] = 2.0;
        lhs[&[2, 0, 1]] = 4.0;

        let mut rhs = SparseTensor::new(vec![3, 2, 2]);
        rhs[&[0, 0, 0]] = 1.0;
        rhs[&[0, 0, 1]] = 2.0;
        rhs[&[0, 1, 0]] = 1.0;
        rhs[&[0, 1, 1]] = 1.0;

        rhs[&[1, 0, 0]] = 1.0;
        rhs[&[1, 0, 1]] = 1.0;
        rhs[&[1, 1, 0]] = 1.0;
        rhs[&[1, 1, 1]] = 2.0;

        rhs[&[2, 0, 0]] = 2.0;
        rhs[&[2, 0, 1]] = 4.0;
        rhs[&[2, 1, 0]] = 1.0;
        rhs[&[2, 1, 1]] = 1.0;

        let res = lhs / rhs;
        assert_eq!(res[&[0, 0, 0]], 2.0);
        assert_eq!(res[&[0, 0, 1]], 2.0);
        assert_eq!(res[&[1, 1, 0]], 2.0);
        assert_eq!(res[&[1, 1, 1]], 2.0);
        assert_eq!(res[&[2, 0, 0]], 1.0);
        assert_eq!(res[&[2, 0, 1]], 1.0);
    }

    /// `tensor / &tensor` keeps the left operand as the dividend.
    #[test]
    fn div_by_ref() {
        let mut lhs = SparseTensor::new(vec![2, 2]);
        lhs[&[0, 0]] = 6.0;
        lhs[&[1, 1]] = 8.0;

        let mut rhs = SparseTensor::new(vec![2, 2]);
        rhs[&[0, 0]] = 3.0;
        rhs[&[1, 1]] = 2.0;

        let res = lhs / &rhs;

        assert_eq!(res[&[0, 0]], 2.0);
        assert_eq!(res[&[1, 1]], 4.0);
    }
}