1use crate::csc::CscMatrix;
2use crate::csr::CsrMatrix;
3
4use crate::ops::Op;
5use crate::ops::serial::{
6 spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_pattern,
7 spmm_csc_prealloc_unchecked, spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc_unchecked,
8};
9use nalgebra::allocator::Allocator;
10use nalgebra::base::storage::RawStorage;
11use nalgebra::constraint::{DimEq, ShapeConstraint};
12use nalgebra::{
13 ClosedAddAssign, ClosedDivAssign, ClosedMulAssign, ClosedSubAssign, DefaultAllocator, Dim, Dyn,
14 Matrix, OMatrix, Scalar, U1,
15};
16use num_traits::{One, Zero};
17use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Neg, Sub};
18
/// Generates a single binary-operator implementation, `impl $trait<$b_type> for $a_type`.
///
/// Invocation shape:
///
/// ```text
/// impl_bin_op!(Trait, method, <'lifetimes.., S: ExtraBound>(lhs: LhsTy, rhs: RhsTy) -> Out { body });
/// ```
///
/// * The angle-bracket list holds any number of lifetimes followed by at most one scalar type
///   parameter, which may carry one extra bound (a path).
/// * When a scalar parameter is given, it is bounded by
///   `Scalar + ClosedAddAssign + ClosedSubAssign + ClosedMulAssign + Zero + One + Neg`, where
///   negation must produce the scalar type itself.
/// * When the list is empty (`<>`), no bounds are emitted, so the macro can also be used for
///   impls over concrete types.
///
/// Inside `body`, the first binding names `self` and the second names the right-hand operand.
macro_rules! impl_bin_op {
    ($trait:ident, $method:ident,
        <$($life:lifetime),* $(,)? $($scalar_type:ident $(: $bounds:path)?)?>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
        =>
    {
        impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
        where
            // The `Neg` output is tied to the caller's scalar identifier instead of a literal
            // `T`, so the macro keeps working when the parameter has a different name.
            $($scalar_type: $($bounds + )? Scalar + ClosedAddAssign + ClosedSubAssign + ClosedMulAssign + Zero + One + Neg<Output = $scalar_type>)?
        {
            type Output = $ret;

            fn $method(self, $b: $b_type) -> Self::Output {
                // Expose `self` under the name the caller picked for the left operand.
                let $a = self;
                $body
            }
        }
    };
}
41
42macro_rules! impl_sp_plus_minus {
45 ($matrix_type:ident, $spadd_fn:ident, +) => {
47 impl_sp_plus_minus!(Add, add, $matrix_type, $spadd_fn, +, T::one());
48 };
49 ($matrix_type:ident, $spadd_fn:ident, -) => {
50 impl_sp_plus_minus!(Sub, sub, $matrix_type, $spadd_fn, -, -T::one());
51 };
52 ($trait:ident, $method:ident, $matrix_type:ident, $spadd_fn:ident, $sign:tt, $factor:expr) => {
53 impl_bin_op!($trait, $method,
54 <'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
55 let pattern = spadd_pattern(a.pattern(), b.pattern());
57 let values = vec![T::zero(); pattern.nnz()];
58 let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
60 .unwrap();
61 $spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
62 $spadd_fn(T::one(), &mut result, $factor, Op::NoOp(&b)).unwrap();
63 result
64 });
65
66 impl_bin_op!($trait, $method,
67 <'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
68 &a $sign b
69 });
70
71 impl_bin_op!($trait, $method,
72 <'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
73 a $sign &b
74 });
75 impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
76 a $sign &b
77 });
78 }
79}
80
81impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, +);
82impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, -);
83impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, +);
84impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, -);
85
/// Shorthand for `impl_bin_op!(Mul, mul, ...)`: every token given to this macro is forwarded
/// unchanged after the trait and method names.
macro_rules! impl_mul {
    ($($rest:tt)*) => {
        impl_bin_op!(Mul, mul, $($rest)*);
    };
}
91
92macro_rules! impl_spmm {
95 ($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
96 impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
97 let pattern = $pattern_fn(a.pattern(), b.pattern());
98 let values = vec![T::zero(); pattern.nnz()];
99 let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
100 .unwrap();
101 $spmm_fn(T::zero(),
102 &mut result,
103 T::one(),
104 Op::NoOp(a),
105 Op::NoOp(b))
106 .expect("Internal error: spmm failed (please debug).");
107 result
108 });
109 impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
110 impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
111 impl_mul!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
112 }
113}
114
115impl_spmm!(CsrMatrix, spmm_csr_pattern, spmm_csr_prealloc_unchecked);
116impl_spmm!(CscMatrix, spmm_csc_pattern, spmm_csc_prealloc_unchecked);
118
/// Implements `scalar * matrix` (scalar on the left) for each listed concrete scalar type.
///
/// All four owned/borrowed combinations are generated. Each one swaps the operands and defers
/// to the corresponding `matrix * scalar` impl, passing the scalar by value.
macro_rules! impl_concrete_scalar_matrix_mul {
    ($matrix_type:ident, $($scalar_type:ty),*) => {
        $(
            impl_mul!(<>(scalar: $scalar_type, matrix: $matrix_type<$scalar_type>)
                -> $matrix_type<$scalar_type> { matrix * scalar });
            impl_mul!(<'a>(scalar: $scalar_type, matrix: &'a $matrix_type<$scalar_type>)
                -> $matrix_type<$scalar_type> { matrix * scalar });
            // Borrowed scalars are dereferenced first so the by-value impls are reused.
            impl_mul!(<'a>(scalar: &'a $scalar_type, matrix: $matrix_type<$scalar_type>)
                -> $matrix_type<$scalar_type> { matrix * *scalar });
            impl_mul!(<'a>(scalar: &'a $scalar_type, matrix: &'a $matrix_type<$scalar_type>)
                -> $matrix_type<$scalar_type> { matrix * *scalar });
        )*
    };
}
137
138macro_rules! impl_scalar_mul {
140 ($matrix_type: ident) => {
141 impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
142 let values: Vec<_> = a.values()
143 .iter()
144 .map(|v_i| v_i.clone() * b.clone())
145 .collect();
146 $matrix_type::try_from_pattern_and_values(a.pattern().clone(), values).unwrap()
147 });
148 impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
149 a * &b
150 });
151 impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
152 let mut a = a;
153 for value in a.values_mut() {
154 *value = b.clone() * value.clone();
155 }
156 a
157 });
158 impl_mul!(<T>(a: $matrix_type<T>, b: T) -> $matrix_type<T> {
159 a * &b
160 });
161 impl_concrete_scalar_matrix_mul!(
162 $matrix_type,
163 i8, i16, i32, i64, isize, f32, f64);
164
165 impl<T> MulAssign<T> for $matrix_type<T>
166 where
167 T: Scalar + ClosedAddAssign + ClosedMulAssign + Zero + One
168 {
169 fn mul_assign(&mut self, scalar: T) {
170 for val in self.values_mut() {
171 *val *= scalar.clone();
172 }
173 }
174 }
175
176 impl<'a, T> MulAssign<&'a T> for $matrix_type<T>
177 where
178 T: Scalar + ClosedAddAssign + ClosedMulAssign + Zero + One
179 {
180 fn mul_assign(&mut self, scalar: &'a T) {
181 for val in self.values_mut() {
182 *val *= scalar.clone();
183 }
184 }
185 }
186 }
187}
188
189impl_scalar_mul!(CsrMatrix);
190impl_scalar_mul!(CscMatrix);
191
192macro_rules! impl_neg {
193 ($matrix_type:ident) => {
194 impl<T> Neg for $matrix_type<T>
195 where
196 T: Scalar + Neg<Output = T>,
197 {
198 type Output = $matrix_type<T>;
199
200 fn neg(mut self) -> Self::Output {
201 for v_i in self.values_mut() {
202 *v_i = -v_i.clone();
203 }
204 self
205 }
206 }
207
208 impl<'a, T> Neg for &'a $matrix_type<T>
209 where
210 T: Scalar + Neg<Output = T>,
211 {
212 type Output = $matrix_type<T>;
213
214 fn neg(self) -> Self::Output {
215 -self.clone()
220 }
221 }
222 };
223}
224
225impl_neg!(CsrMatrix);
226impl_neg!(CscMatrix);
227
228macro_rules! impl_div {
229 ($matrix_type:ident) => {
230 impl_bin_op!(Div, div, <T: ClosedDivAssign>(matrix: $matrix_type<T>, scalar: T) -> $matrix_type<T> {
231 let mut matrix = matrix;
232 matrix /= scalar;
233 matrix
234 });
235 impl_bin_op!(Div, div, <'a, T: ClosedDivAssign>(matrix: $matrix_type<T>, scalar: &T) -> $matrix_type<T> {
236 matrix / scalar.clone()
237 });
238 impl_bin_op!(Div, div, <'a, T: ClosedDivAssign>(matrix: &'a $matrix_type<T>, scalar: T) -> $matrix_type<T> {
239 let new_values = matrix.values()
240 .iter()
241 .map(|v_i| v_i.clone() / scalar.clone())
242 .collect();
243 $matrix_type::try_from_pattern_and_values(matrix.pattern().clone(), new_values)
244 .unwrap()
245 });
246 impl_bin_op!(Div, div, <'a, T: ClosedDivAssign>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
247 matrix / scalar.clone()
248 });
249
250 impl<T> DivAssign<T> for $matrix_type<T>
251 where T : Scalar + ClosedAddAssign + ClosedMulAssign + ClosedDivAssign + Zero + One
252 {
253 fn div_assign(&mut self, scalar: T) {
254 self.values_mut().iter_mut().for_each(|v_i| *v_i /= scalar.clone());
255 }
256 }
257
258 impl<'a, T> DivAssign<&'a T> for $matrix_type<T>
259 where T : Scalar + ClosedAddAssign + ClosedMulAssign + ClosedDivAssign + Zero + One
260 {
261 fn div_assign(&mut self, scalar: &'a T) {
262 *self /= scalar.clone();
263 }
264 }
265 }
266}
267
268impl_div!(CsrMatrix);
269impl_div!(CscMatrix);
270
271macro_rules! impl_spmm_cs_dense {
272 ($matrix_type_name:ident, $spmm_fn:ident) => {
273 impl_spmm_cs_dense!(&'a $matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
275 let (_, ncols) = rhs.shape_generic();
276 let nrows = Dyn(lhs.nrows());
277 let mut result = OMatrix::<T, Dyn, C>::zeros_generic(nrows, ncols);
278 $spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(lhs), Op::NoOp(rhs));
279 result
280 });
281
282 impl_spmm_cs_dense!(&'a $matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
284 lhs * &rhs
285 });
286 impl_spmm_cs_dense!($matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
287 &lhs * rhs
288 });
289 impl_spmm_cs_dense!($matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
290 &lhs * &rhs
291 });
292 };
293
294 ($sparse_matrix_type:ty, $dense_matrix_type:ty, $spmm_fn:ident,
297 |$lhs:ident, $rhs:ident| $body:tt) =>
298 {
299 impl<'a, T, R, C, S> Mul<$dense_matrix_type> for $sparse_matrix_type
300 where
301 T: Scalar + ClosedMulAssign + ClosedAddAssign + ClosedSubAssign + ClosedDivAssign + Neg + Zero + One,
302 R: Dim,
303 C: Dim,
304 S: RawStorage<T, R, C>,
305 DefaultAllocator: Allocator<Dyn, C>,
306 ShapeConstraint:
308 DimEq<U1, <<DefaultAllocator as Allocator<Dyn, C>>::Buffer<T> as RawStorage<T, Dyn, C>>::RStride>
310 + DimEq<C, Dyn>
311 + DimEq<Dyn, <<DefaultAllocator as Allocator<Dyn, C>>::Buffer<T> as RawStorage<T, Dyn, C>>::CStride>
312 + DimEq<U1, S::RStride>
314 + DimEq<R, Dyn>
315 + DimEq<Dyn, S::CStride>
316 {
317 type Output = OMatrix<T, Dyn, C>;
320
321 fn mul(self, rhs: $dense_matrix_type) -> Self::Output {
322 let $lhs = self;
323 let $rhs = rhs;
324 $body
325 }
326 }
327 }
328}
329
330impl_spmm_cs_dense!(CsrMatrix, spmm_csr_dense);
331impl_spmm_cs_dense!(CscMatrix, spmm_csc_dense);