opensrdk_linear_algebra/tensor/sparse/operators/
mul.rs1use crate::{
2 indices_cartesian_product,
3 number::{c64, Number},
4 sparse::SparseTensor,
5};
6use rayon::prelude::*;
7use std::ops::{Mul, MulAssign};
8
9fn mul_scalar<T>(lhs: T, rhs: SparseTensor<T>) -> SparseTensor<T>
10where
11 T: Number,
12{
13 let mut rhs = rhs;
14
15 rhs.elems
16 .par_iter_mut()
17 .map(|r| {
18 *r.1 *= lhs;
19 })
20 .collect::<Vec<_>>();
21
22 rhs
23}
24
25fn mul<T>(lhs: SparseTensor<T>, rhs: &SparseTensor<T>) -> SparseTensor<T>
26where
27 T: Number,
28{
29 if !lhs.is_same_size(rhs) {
30 panic!("Dimension mismatch.")
31 }
32 let mut lhs = lhs;
33
34 indices_cartesian_product(&lhs.sizes)
35 .into_iter()
36 .for_each(|k| {
37 if !lhs.elems.contains_key(&k) {
38 return;
39 }
40 if !rhs.elems.contains_key(&k) {
41 lhs.elems.remove(&k);
42 return;
43 }
44 lhs[&k] *= rhs[&k];
45 });
46
47 lhs
48}
49
50macro_rules! impl_div_scalar {
53 {$t: ty} => {
54 impl Mul<SparseTensor<$t>> for $t {
55 type Output = SparseTensor<$t>;
56
57 fn mul(self, rhs: SparseTensor<$t>) -> Self::Output {
58 mul_scalar(self, rhs)
59 }
60 }
61
62 impl Mul<SparseTensor<$t>> for &$t {
63 type Output = SparseTensor<$t>;
64
65 fn mul(self, rhs: SparseTensor<$t>) -> Self::Output {
66 mul_scalar(*self, rhs)
67 }
68 }
69 }
70}
71
72impl_div_scalar! {f64}
73impl_div_scalar! {c64}
74
75impl<T> Mul<T> for SparseTensor<T>
78where
79 T: Number,
80{
81 type Output = SparseTensor<T>;
82
83 fn mul(self, rhs: T) -> Self::Output {
84 mul_scalar(rhs, self)
85 }
86}
87
88impl<T> Mul<&T> for SparseTensor<T>
89where
90 T: Number,
91{
92 type Output = SparseTensor<T>;
93
94 fn mul(self, rhs: &T) -> Self::Output {
95 mul_scalar(*rhs, self)
96 }
97}
98
99impl<T> Mul<SparseTensor<T>> for SparseTensor<T>
102where
103 T: Number,
104{
105 type Output = SparseTensor<T>;
106
107 fn mul(self, rhs: SparseTensor<T>) -> Self::Output {
108 mul(self, &rhs)
109 }
110}
111
112impl<T> Mul<&SparseTensor<T>> for SparseTensor<T>
113where
114 T: Number,
115{
116 type Output = SparseTensor<T>;
117
118 fn mul(self, rhs: &SparseTensor<T>) -> Self::Output {
119 mul(self, rhs)
120 }
121}
122
123impl<T> Mul<SparseTensor<T>> for &SparseTensor<T>
124where
125 T: Number,
126{
127 type Output = SparseTensor<T>;
128
129 fn mul(self, rhs: SparseTensor<T>) -> Self::Output {
130 mul(rhs, self)
131 }
132}
133
134impl<T> MulAssign<SparseTensor<T>> for SparseTensor<T>
137where
138 T: Number,
139{
140 fn mul_assign(&mut self, rhs: SparseTensor<T>) {
141 *self = self as &Self * rhs;
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    // 3x2x2 tensor with six stored entries, shared by the tests below.
    fn sample() -> SparseTensor<f64> {
        let mut t = SparseTensor::new(vec![3, 2, 2]);
        t[&[0, 0, 0]] = 2.0;
        t[&[0, 0, 1]] = 4.0;
        t[&[1, 1, 0]] = 2.0;
        t[&[1, 1, 1]] = 4.0;
        t[&[2, 0, 0]] = 2.0;
        t[&[2, 0, 1]] = 4.0;
        t
    }

    #[test]
    fn mul_scalar() {
        let a = sample();

        let mut expected = SparseTensor::new(vec![3, 2, 2]);
        expected[&[0, 0, 0]] = 4.0;
        expected[&[0, 0, 1]] = 8.0;
        expected[&[1, 1, 0]] = 4.0;
        expected[&[1, 1, 1]] = 8.0;
        expected[&[2, 0, 0]] = 4.0;
        expected[&[2, 0, 1]] = 8.0;

        let b = 2.0 * a.clone();
        let c = a.clone() * 2.0;
        // Exercises the `&scalar * tensor` impl (this line used to duplicate `b`).
        let d = &2.0 * a.clone();
        let e = a * &2.0;

        assert_eq!(b, expected);
        assert_eq!(b, c);
        assert_eq!(c, d);
        assert_eq!(d, e);
    }

    #[test]
    fn mul() {
        let a = sample();
        let b = sample();

        let mut c = SparseTensor::new(vec![3, 2, 2]);
        c[&[0, 0, 0]] = 4.0;
        c[&[0, 0, 1]] = 16.0;
        c[&[1, 1, 0]] = 4.0;
        c[&[1, 1, 1]] = 16.0;
        c[&[2, 0, 0]] = 4.0;
        c[&[2, 0, 1]] = 16.0;

        let d = a.clone() * b.clone();
        let e = b.clone() * a.clone();
        let f = a.clone() * &b;
        let g = &a * b.clone();
        let mut h = a.clone();
        h *= b;

        assert_eq!(c, d);
        assert_eq!(d, e);
        assert_eq!(e, f);
        assert_eq!(f, g);
        assert_eq!(g, h);
    }

    #[test]
    fn mul_drops_entries_missing_from_either_operand() {
        let mut a = SparseTensor::new(vec![3, 2, 2]);
        a[&[0, 0, 0]] = 2.0;
        a[&[1, 0, 0]] = 3.0;

        let mut b = SparseTensor::new(vec![3, 2, 2]);
        b[&[0, 0, 0]] = 5.0;
        b[&[2, 1, 1]] = 7.0;

        let mut expected = SparseTensor::new(vec![3, 2, 2]);
        expected[&[0, 0, 0]] = 10.0;

        assert_eq!(a * b, expected);
    }

    #[test]
    #[should_panic(expected = "Dimension mismatch.")]
    fn mul_panics_on_dimension_mismatch() {
        let a = SparseTensor::<f64>::new(vec![3, 2, 2]);
        let b = SparseTensor::<f64>::new(vec![2, 2]);
        let _ = a * b;
    }
}