opensrdk_linear_algebra/tensor/sparse/operators/
add.rs

1use crate::{
2    number::{c64, Number},
3    sparse::SparseTensor,
4};
5use rayon::prelude::*;
6use std::ops::{Add, AddAssign};
7
8fn add_scalar<T>(lhs: T, rhs: SparseTensor<T>) -> SparseTensor<T>
9where
10    T: Number,
11{
12    let mut rhs = rhs;
13
14    rhs.elems
15        .par_iter_mut()
16        .map(|r| {
17            *r.1 += lhs;
18        })
19        .collect::<Vec<_>>();
20
21    rhs
22}
23
24fn add<T>(lhs: SparseTensor<T>, rhs: &SparseTensor<T>) -> SparseTensor<T>
25where
26    T: Number,
27{
28    if !lhs.is_same_size(rhs) {
29        panic!("Dimension mismatch.")
30    }
31    let mut lhs = lhs;
32
33    rhs.elems.iter().for_each(|(k, v)| {
34        lhs[k] += *v;
35    });
36
37    lhs
38}
39
40// Scalar and SparseTensor
41
42macro_rules! impl_div_scalar {
43    {$t: ty} => {
44        impl Add<SparseTensor<$t>> for $t {
45            type Output = SparseTensor<$t>;
46
47            fn add(self, rhs: SparseTensor<$t>) -> Self::Output {
48                add_scalar(self, rhs)
49            }
50        }
51
52        impl Add<SparseTensor<$t>> for &$t {
53            type Output = SparseTensor<$t>;
54
55            fn add(self, rhs: SparseTensor<$t>) -> Self::Output {
56                add_scalar(*self, rhs)
57            }
58        }
59    }
60}
61
62impl_div_scalar! {f64}
63impl_div_scalar! {c64}
64
65// SparseTensor and Scalar
66
67impl<T> Add<T> for SparseTensor<T>
68where
69    T: Number,
70{
71    type Output = SparseTensor<T>;
72
73    fn add(self, rhs: T) -> Self::Output {
74        add_scalar(rhs, self)
75    }
76}
77
78impl<T> Add<&T> for SparseTensor<T>
79where
80    T: Number,
81{
82    type Output = SparseTensor<T>;
83
84    fn add(self, rhs: &T) -> Self::Output {
85        add_scalar(*rhs, self)
86    }
87}
88
89// SparseTensor and SparseTensor
90
91impl<T> Add<SparseTensor<T>> for SparseTensor<T>
92where
93    T: Number,
94{
95    type Output = SparseTensor<T>;
96
97    fn add(self, rhs: SparseTensor<T>) -> Self::Output {
98        add(self, &rhs)
99    }
100}
101
102impl<T> Add<&SparseTensor<T>> for SparseTensor<T>
103where
104    T: Number,
105{
106    type Output = SparseTensor<T>;
107
108    fn add(self, rhs: &SparseTensor<T>) -> Self::Output {
109        add(self, rhs)
110    }
111}
112
113impl<T> Add<SparseTensor<T>> for &SparseTensor<T>
114where
115    T: Number,
116{
117    type Output = SparseTensor<T>;
118
119    fn add(self, rhs: SparseTensor<T>) -> Self::Output {
120        add(rhs, self)
121    }
122}
123
124// AddAssign
125
126impl<T> AddAssign<SparseTensor<T>> for SparseTensor<T>
127where
128    T: Number,
129{
130    fn add_assign(&mut self, rhs: SparseTensor<T>) {
131        *self = self as &Self + rhs;
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the 3x2x2 fixture shared by the tests. `even` is stored where the
    /// last index is 0 and `odd` where it is 1, at the same six positions.
    fn fixture(even: f64, odd: f64) -> SparseTensor<f64> {
        let mut t = SparseTensor::new(vec![3, 2, 2]);
        t[&[0, 0, 0]] = even;
        t[&[0, 0, 1]] = odd;
        t[&[1, 1, 0]] = even;
        t[&[1, 1, 1]] = odd;
        t[&[2, 0, 0]] = even;
        t[&[2, 0, 1]] = odd;
        t
    }

    #[test]
    fn add_scalar() {
        let a = fixture(2.0, 4.0);
        let b = fixture(4.0, 6.0);

        assert_eq!(a.clone() + 2.0, b);
        assert_eq!(2.0 + a.clone(), b);
        assert_eq!(&2.0 + a.clone(), b);
        assert_eq!(a + &2.0, b);
    }

    #[test]
    fn add() {
        let a = fixture(2.0, 4.0);
        let b = fixture(4.0, 6.0);
        let c = fixture(6.0, 10.0);

        assert_eq!(a.clone() + b.clone(), b.clone() + a.clone());
        assert_eq!(a.clone() + b.clone(), c.clone());
        assert_eq!(b.clone() + a.clone(), a.clone() + b.clone());
        // The mixed owned/borrowed operand impls must agree with the owned form.
        assert_eq!(a.clone() + &b, c);
        assert_eq!(&a + b.clone(), c);
        assert_eq!(a + b, c);
    }

    #[test]
    fn add_assign() {
        let mut a = fixture(2.0, 4.0);
        a += fixture(4.0, 6.0);
        assert_eq!(a, fixture(6.0, 10.0));
    }

    #[test]
    #[should_panic(expected = "Dimension mismatch.")]
    fn add_dimension_mismatch() {
        let _ = SparseTensor::<f64>::new(vec![3, 2, 2]) + SparseTensor::<f64>::new(vec![3, 2]);
    }
}