opensrdk_linear_algebra/tensor/sparse/operators/sub.rs

1use crate::{
2    number::{c64, Number},
3    sparse::SparseTensor,
4};
5use rayon::prelude::*;
6use std::ops::{Sub, SubAssign};
7
8fn sub_scalar<T>(lhs: T, rhs: SparseTensor<T>) -> SparseTensor<T>
9where
10    T: Number,
11{
12    let mut rhs = rhs;
13
14    rhs.elems
15        .par_iter_mut()
16        .map(|r| {
17            *r.1 -= lhs;
18        })
19        .collect::<Vec<_>>();
20
21    rhs
22}
23
24fn sub<T>(lhs: SparseTensor<T>, rhs: &SparseTensor<T>) -> SparseTensor<T>
25where
26    T: Number,
27{
28    if !lhs.is_same_size(rhs) {
29        panic!("Dimension mismatch.")
30    }
31    let mut lhs = lhs;
32
33    rhs.elems.iter().for_each(|(k, v)| {
34        lhs[k] -= *v;
35    });
36
37    lhs
38}
39
40// Scalar and SparseTensor
41
42macro_rules! impl_div_scalar {
43  {$t: ty} => {
44      impl Sub<SparseTensor<$t>> for $t {
45          type Output = SparseTensor<$t>;
46
47          fn sub(self, rhs: SparseTensor<$t>) -> Self::Output {
48              sub_scalar(self, rhs)
49          }
50      }
51
52      impl Sub<SparseTensor<$t>> for &$t {
53          type Output = SparseTensor<$t>;
54
55          fn sub(self, rhs: SparseTensor<$t>) -> Self::Output {
56              sub_scalar(*self, rhs)
57          }
58      }
59  }
60}
61
62impl_div_scalar! {f64}
63impl_div_scalar! {c64}
64
65// SparseTensor and Scalar
66
67impl<T> Sub<T> for SparseTensor<T>
68where
69    T: Number,
70{
71    type Output = SparseTensor<T>;
72
73    fn sub(self, rhs: T) -> Self::Output {
74        sub_scalar(rhs, self)
75    }
76}
77
78impl<T> Sub<&T> for SparseTensor<T>
79where
80    T: Number,
81{
82    type Output = SparseTensor<T>;
83
84    fn sub(self, rhs: &T) -> Self::Output {
85        sub_scalar(*rhs, self)
86    }
87}
88
89// SparseTensor and SparseTensor
90
91impl<T> Sub<SparseTensor<T>> for SparseTensor<T>
92where
93    T: Number,
94{
95    type Output = SparseTensor<T>;
96
97    fn sub(self, rhs: SparseTensor<T>) -> Self::Output {
98        sub(self, &rhs)
99    }
100}
101
102impl<T> Sub<&SparseTensor<T>> for SparseTensor<T>
103where
104    T: Number,
105{
106    type Output = SparseTensor<T>;
107
108    fn sub(self, rhs: &SparseTensor<T>) -> Self::Output {
109        sub(self, rhs)
110    }
111}
112
113impl<T> Sub<SparseTensor<T>> for &SparseTensor<T>
114where
115    T: Number,
116{
117    type Output = SparseTensor<T>;
118
119    fn sub(self, rhs: SparseTensor<T>) -> Self::Output {
120        -sub(rhs, self)
121    }
122}
123
124// SubAssign
125
126impl<T> SubAssign<SparseTensor<T>> for SparseTensor<T>
127where
128    T: Number,
129{
130    fn sub_assign(&mut self, rhs: SparseTensor<T>) {
131        *self = self as &Self - rhs;
132    }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// `tensor - &scalar` shifts every stored entry down by the scalar.
    #[test]
    fn sub_scalar() {
        let mut a = SparseTensor::new(vec![3, 2, 2]);

        a[&[0, 0, 0]] = 2.0;
        a[&[0, 0, 1]] = 4.0;
        a[&[1, 1, 0]] = 2.0;
        a[&[1, 1, 1]] = 4.0;
        a[&[2, 0, 0]] = 2.0;
        a[&[2, 0, 1]] = 4.0;

        let b = a - &2.0;

        assert_eq!(b[&[0, 0, 0]], 0.0);
        assert_eq!(b[&[0, 0, 1]], 2.0);
        assert_eq!(b[&[1, 1, 0]], 0.0);
        assert_eq!(b[&[1, 1, 1]], 2.0);
        assert_eq!(b[&[2, 0, 0]], 0.0);
        assert_eq!(b[&[2, 0, 1]], 2.0);
    }

    /// Owned `tensor - tensor` subtracts entry by entry.
    #[test]
    fn sub() {
        let mut a = SparseTensor::new(vec![3, 2, 2]);

        a[&[0, 0, 0]] = 2.0;
        a[&[0, 0, 1]] = 4.0;
        a[&[1, 1, 0]] = 2.0;
        a[&[1, 1, 1]] = 4.0;
        a[&[2, 0, 0]] = 2.0;
        a[&[2, 0, 1]] = 4.0;

        let mut b = SparseTensor::new(vec![3, 2, 2]);
        b[&[0, 0, 0]] = 2.0;
        b[&[0, 0, 1]] = 4.0;
        b[&[1, 1, 0]] = 2.0;
        b[&[1, 1, 1]] = 1.0;
        b[&[2, 0, 0]] = 1.0;
        b[&[2, 0, 1]] = 2.0;

        let c = a - b;

        assert_eq!(c[&[0, 0, 0]], 0.0);
        assert_eq!(c[&[0, 0, 1]], 0.0);
        assert_eq!(c[&[1, 1, 0]], 0.0);
        assert_eq!(c[&[1, 1, 1]], 3.0);
        assert_eq!(c[&[2, 0, 0]], 1.0);
        assert_eq!(c[&[2, 0, 1]], 2.0);
    }

    /// `-=` handles entries stored in both operands, only in `self`, and only in `rhs`.
    #[test]
    fn sub_assign() {
        let mut a = SparseTensor::new(vec![3, 2, 2]);
        a[&[0, 0, 0]] = 2.0;
        a[&[1, 1, 1]] = 4.0;

        let mut b = SparseTensor::new(vec![3, 2, 2]);
        b[&[0, 0, 0]] = 2.0;
        b[&[2, 0, 1]] = 3.0;

        a -= b;

        assert_eq!(a[&[0, 0, 0]], 0.0);
        assert_eq!(a[&[1, 1, 1]], 4.0);
        assert_eq!(a[&[2, 0, 1]], -3.0);
    }
}