opensrdk_linear_algebra/tensor/sparse/operators/
div.rs1use crate::{
2 indices_cartesian_product,
3 number::{c64, Number},
4 sparse::SparseTensor,
5};
6use rayon::prelude::*;
7use std::ops::{Div, DivAssign};
8
9fn div_scalar<T>(lhs: T, rhs: SparseTensor<T>) -> SparseTensor<T>
10where
11 T: Number,
12{
13 let mut rhs = rhs;
14
15 rhs.elems
16 .par_iter_mut()
17 .map(|r| {
18 *r.1 /= lhs;
19 })
20 .collect::<Vec<_>>();
21
22 rhs
23}
24
25fn div<T>(lhs: SparseTensor<T>, rhs: &SparseTensor<T>) -> SparseTensor<T>
26where
27 T: Number,
28{
29 if !lhs.is_same_size(rhs) {
30 panic!("Dimension mismatch.")
31 }
32 let mut lhs = lhs;
33
34 indices_cartesian_product(&lhs.sizes)
35 .into_iter()
36 .for_each(|k| {
37 if !lhs.elems.contains_key(&k) {
38 return;
39 }
40 lhs[&k] /= rhs[&k];
41 });
42
43 lhs
44}
45
46macro_rules! impl_div_scalar {
49 {$t: ty} => {
50 impl Div<SparseTensor<$t>> for $t {
51 type Output = SparseTensor<$t>;
52
53 fn div(self, rhs: SparseTensor<$t>) -> Self::Output {
54 div_scalar(self, rhs)
55 }
56 }
57
58 impl Div<SparseTensor<$t>> for &$t {
59 type Output = SparseTensor<$t>;
60
61 fn div(self, rhs: SparseTensor<$t>) -> Self::Output {
62 div_scalar(*self, rhs)
63 }
64 }
65 }
66}
67
68impl_div_scalar! {f64}
69impl_div_scalar! {c64}
70
71impl<T> Div<T> for SparseTensor<T>
74where
75 T: Number,
76{
77 type Output = SparseTensor<T>;
78
79 fn div(self, rhs: T) -> Self::Output {
80 div_scalar(rhs, self)
81 }
82}
83
84impl<T> Div<&T> for SparseTensor<T>
85where
86 T: Number,
87{
88 type Output = SparseTensor<T>;
89
90 fn div(self, rhs: &T) -> Self::Output {
91 div_scalar(*rhs, self)
92 }
93}
94
95impl<T> Div<SparseTensor<T>> for SparseTensor<T>
98where
99 T: Number,
100{
101 type Output = SparseTensor<T>;
102
103 fn div(self, rhs: SparseTensor<T>) -> Self::Output {
104 div(self, &rhs)
105 }
106}
107
108impl<T> Div<&SparseTensor<T>> for SparseTensor<T>
109where
110 T: Number,
111{
112 type Output = SparseTensor<T>;
113
114 fn div(self, rhs: &SparseTensor<T>) -> Self::Output {
115 div(self, rhs)
116 }
117}
118
119impl<T> Div<SparseTensor<T>> for &SparseTensor<T>
120where
121 T: Number,
122{
123 type Output = SparseTensor<T>;
124
125 fn div(self, rhs: SparseTensor<T>) -> Self::Output {
126 div(rhs, self)
127 }
128}
129
130impl<T> DivAssign<SparseTensor<T>> for SparseTensor<T>
133where
134 T: Number,
135{
136 fn div_assign(&mut self, rhs: SparseTensor<T>) {
137 *self = self as &Self / rhs;
138 }
139}
140
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// `tensor / scalar` halves every stored entry and leaves the rest unstored.
    #[test]
    fn div_scalar() {
        let mut lhs = SparseTensor::new(vec![3, 2, 2]);
        lhs[&[0, 0, 0]] = 2.0;
        lhs[&[0, 0, 1]] = 4.0;
        lhs[&[1, 1, 0]] = 2.0;
        lhs[&[1, 1, 1]] = 4.0;
        lhs[&[2, 0, 0]] = 2.0;
        lhs[&[2, 0, 1]] = 4.0;

        let mut hash2 = HashMap::new();

        hash2.insert(vec![0usize, 0, 0], 1.0);
        hash2.insert(vec![0usize, 0, 1], 2.0);

        hash2.insert(vec![1usize, 1, 0], 1.0);
        hash2.insert(vec![1usize, 1, 1], 2.0);

        hash2.insert(vec![2usize, 0, 0], 1.0);
        hash2.insert(vec![2usize, 0, 1], 2.0);

        let rhs = SparseTensor::from(vec![3, 2, 2], hash2).unwrap();

        let res = lhs / 2.0;

        assert_eq!(res, rhs);
    }

    /// `tensor / &scalar` behaves like division by the scalar value.
    #[test]
    fn div_scalar_by_ref() {
        let mut lhs = SparseTensor::<f64>::new(vec![2, 2]);
        lhs[&[0, 1]] = 3.0;
        lhs[&[1, 0]] = 6.0;

        let res = lhs / &3.0;

        assert_eq!(res[&[0, 1]], 1.0);
        assert_eq!(res[&[1, 0]], 2.0);
    }

    /// Element-wise `tensor / tensor` over the entries stored in the dividend.
    #[test]
    fn div() {
        let mut lhs = SparseTensor::new(vec![3, 2, 2]);
        lhs[&[0, 0, 0]] = 2.0;
        lhs[&[0, 0, 1]] = 4.0;
        lhs[&[1, 1, 0]] = 2.0;
        lhs[&[1, 1, 1]] = 4.0;
        lhs[&[2, 0, 0]] = 2.0;
        lhs[&[2, 0, 1]] = 4.0;

        let mut rhs = SparseTensor::new(vec![3, 2, 2]);
        rhs[&[0, 0, 0]] = 1.0;
        rhs[&[0, 0, 1]] = 2.0;
        rhs[&[0, 1, 0]] = 1.0;
        rhs[&[0, 1, 1]] = 1.0;

        rhs[&[1, 0, 0]] = 1.0;
        rhs[&[1, 0, 1]] = 1.0;
        rhs[&[1, 1, 0]] = 1.0;
        rhs[&[1, 1, 1]] = 2.0;

        rhs[&[2, 0, 0]] = 2.0;
        rhs[&[2, 0, 1]] = 4.0;
        rhs[&[2, 1, 0]] = 1.0;
        rhs[&[2, 1, 1]] = 1.0;

        let res = lhs / rhs;
        assert_eq!(res[&[0, 0, 0]], 2.0);
        assert_eq!(res[&[0, 0, 1]], 2.0);
        assert_eq!(res[&[1, 1, 0]], 2.0);
        assert_eq!(res[&[1, 1, 1]], 2.0);
        assert_eq!(res[&[2, 0, 0]], 1.0);
        assert_eq!(res[&[2, 0, 1]], 1.0);
    }
}