// opensrdk_linear_algebra/tensor/sparse/operators/mul.rs

1use crate::{
2    indices_cartesian_product,
3    number::{c64, Number},
4    sparse::SparseTensor,
5};
6use rayon::prelude::*;
7use std::ops::{Mul, MulAssign};
8
9fn mul_scalar<T>(lhs: T, rhs: SparseTensor<T>) -> SparseTensor<T>
10where
11    T: Number,
12{
13    let mut rhs = rhs;
14
15    rhs.elems
16        .par_iter_mut()
17        .map(|r| {
18            *r.1 *= lhs;
19        })
20        .collect::<Vec<_>>();
21
22    rhs
23}
24
25fn mul<T>(lhs: SparseTensor<T>, rhs: &SparseTensor<T>) -> SparseTensor<T>
26where
27    T: Number,
28{
29    if !lhs.is_same_size(rhs) {
30        panic!("Dimension mismatch.")
31    }
32    let mut lhs = lhs;
33
34    indices_cartesian_product(&lhs.sizes)
35        .into_iter()
36        .for_each(|k| {
37            if !lhs.elems.contains_key(&k) {
38                return;
39            }
40            if !rhs.elems.contains_key(&k) {
41                lhs.elems.remove(&k);
42                return;
43            }
44            lhs[&k] *= rhs[&k];
45        });
46
47    lhs
48}
49
// Scalar * SparseTensor (scalar on the left-hand side)
51
52macro_rules! impl_div_scalar {
53  {$t: ty} => {
54      impl Mul<SparseTensor<$t>> for $t {
55          type Output = SparseTensor<$t>;
56
57          fn mul(self, rhs: SparseTensor<$t>) -> Self::Output {
58              mul_scalar(self, rhs)
59          }
60      }
61
62      impl Mul<SparseTensor<$t>> for &$t {
63          type Output = SparseTensor<$t>;
64
65          fn mul(self, rhs: SparseTensor<$t>) -> Self::Output {
66              mul_scalar(*self, rhs)
67          }
68      }
69  }
70}
71
72impl_div_scalar! {f64}
73impl_div_scalar! {c64}
74
// SparseTensor * Scalar (scalar on the right-hand side)
76
77impl<T> Mul<T> for SparseTensor<T>
78where
79    T: Number,
80{
81    type Output = SparseTensor<T>;
82
83    fn mul(self, rhs: T) -> Self::Output {
84        mul_scalar(rhs, self)
85    }
86}
87
88impl<T> Mul<&T> for SparseTensor<T>
89where
90    T: Number,
91{
92    type Output = SparseTensor<T>;
93
94    fn mul(self, rhs: &T) -> Self::Output {
95        mul_scalar(*rhs, self)
96    }
97}
98
// SparseTensor * SparseTensor (element-wise product)
100
101impl<T> Mul<SparseTensor<T>> for SparseTensor<T>
102where
103    T: Number,
104{
105    type Output = SparseTensor<T>;
106
107    fn mul(self, rhs: SparseTensor<T>) -> Self::Output {
108        mul(self, &rhs)
109    }
110}
111
112impl<T> Mul<&SparseTensor<T>> for SparseTensor<T>
113where
114    T: Number,
115{
116    type Output = SparseTensor<T>;
117
118    fn mul(self, rhs: &SparseTensor<T>) -> Self::Output {
119        mul(self, rhs)
120    }
121}
122
123impl<T> Mul<SparseTensor<T>> for &SparseTensor<T>
124where
125    T: Number,
126{
127    type Output = SparseTensor<T>;
128
129    fn mul(self, rhs: SparseTensor<T>) -> Self::Output {
130        mul(rhs, self)
131    }
132}
133
// MulAssign (in-place element-wise product)
135
136impl<T> MulAssign<SparseTensor<T>> for SparseTensor<T>
137where
138    T: Number,
139{
140    fn mul_assign(&mut self, rhs: SparseTensor<T>) {
141        *self = self as &Self * rhs;
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 3x2x2 fixture shared by several tests: six stored entries
    /// alternating between 2.0 and 4.0.
    fn sample() -> SparseTensor<f64> {
        let mut a = SparseTensor::new(vec![3, 2, 2]);
        a[&[0, 0, 0]] = 2.0;
        a[&[0, 0, 1]] = 4.0;
        a[&[1, 1, 0]] = 2.0;
        a[&[1, 1, 1]] = 4.0;
        a[&[2, 0, 0]] = 2.0;
        a[&[2, 0, 1]] = 4.0;
        a
    }

    #[test]
    fn mul_scalar() {
        let a = sample();

        let mut expected = SparseTensor::new(vec![3, 2, 2]);
        expected[&[0, 0, 0]] = 4.0;
        expected[&[0, 0, 1]] = 8.0;
        expected[&[1, 1, 0]] = 4.0;
        expected[&[1, 1, 1]] = 8.0;
        expected[&[2, 0, 0]] = 4.0;
        expected[&[2, 0, 1]] = 8.0;

        let b = 2.0 * a.clone();
        let c = a.clone() * 2.0;
        // Scalar taken by reference on the left: exercises the `&f64 * tensor` impl.
        let d = &2.0 * a.clone();
        let e = a * &2.0;

        assert_eq!(b, expected);
        assert_eq!(c, expected);
        assert_eq!(d, expected);
        assert_eq!(e, expected);
    }

    #[test]
    fn mul() {
        let a = sample();
        let b = sample();

        let mut expected = SparseTensor::new(vec![3, 2, 2]);
        expected[&[0, 0, 0]] = 4.0;
        expected[&[0, 0, 1]] = 16.0;
        expected[&[1, 1, 0]] = 4.0;
        expected[&[1, 1, 1]] = 16.0;
        expected[&[2, 0, 0]] = 4.0;
        expected[&[2, 0, 1]] = 16.0;

        let owned = a.clone() * b.clone();
        let swapped = b.clone() * a.clone();
        let rhs_ref = a.clone() * &b;
        let lhs_ref = &a * b.clone();
        let mut assigned = a;
        assigned *= b;

        assert_eq!(owned, expected);
        assert_eq!(swapped, expected);
        assert_eq!(rhs_ref, expected);
        assert_eq!(lhs_ref, expected);
        assert_eq!(assigned, expected);
    }

    #[test]
    fn mul_keeps_only_shared_entries() {
        let mut a = SparseTensor::new(vec![3, 2, 2]);
        a[&[0, 0, 0]] = 3.0;
        a[&[0, 0, 1]] = 5.0;

        let mut b = SparseTensor::new(vec![3, 2, 2]);
        b[&[0, 0, 0]] = 2.0;
        b[&[2, 1, 1]] = 7.0;

        // Only [0, 0, 0] is stored in both operands; every other entry is dropped.
        let mut expected = SparseTensor::new(vec![3, 2, 2]);
        expected[&[0, 0, 0]] = 6.0;

        assert_eq!(a.clone() * b.clone(), expected);
        assert_eq!(b * a, expected);
    }

    #[test]
    #[should_panic(expected = "Dimension mismatch.")]
    fn mul_panics_on_dimension_mismatch() {
        let a = SparseTensor::<f64>::new(vec![3, 2, 2]);
        let b = SparseTensor::<f64>::new(vec![2, 5]);
        let _ = a * b;
    }
}