1use std::ops::{Add, Deref, Mul, Sub};
4
5use crate::errors::StructureError;
6use crate::indexing::SpIndex;
7use crate::sparse::compressed::SpMatView;
8use crate::sparse::csmat::CompressedStorage;
9use crate::sparse::prelude::*;
10use crate::sparse::vec::NnzEither::{Both, Left, Right};
11use crate::sparse::vec::SparseIterTools;
12use crate::IndPtr;
13use ndarray::{
14 self, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, ShapeBuilder,
15};
16use num_traits::Zero;
17
18use crate::Ix2;
19
20impl<
21 'a,
22 'b,
23 Lhs,
24 Rhs,
25 Res,
26 I,
27 Iptr,
28 IpStorage,
29 IStorage,
30 DStorage,
31 IpS2,
32 IS2,
33 DS2,
34 > Add<&'b CsMatBase<Rhs, I, IpS2, IS2, DS2, Iptr>>
35 for &'a CsMatBase<Lhs, I, IpStorage, IStorage, DStorage, Iptr>
36where
37 Lhs: Zero,
38 Rhs: Zero + Clone + Default,
39 Res: Zero + Clone,
40 for<'r> &'r Lhs: Add<&'r Rhs, Output = Res>,
41 I: 'a + SpIndex,
42 Iptr: 'a + SpIndex,
43 IpStorage: 'a + Deref<Target = [Iptr]>,
44 IStorage: 'a + Deref<Target = [I]>,
45 DStorage: 'a + Deref<Target = [Lhs]>,
46 IpS2: 'b + Deref<Target = [Iptr]>,
47 IS2: 'b + Deref<Target = [I]>,
48 DS2: 'b + Deref<Target = [Rhs]>,
49{
50 type Output = CsMatI<Res, I, Iptr>;
51
52 fn add(
53 self,
54 rhs: &'b CsMatBase<Rhs, I, IpS2, IS2, DS2, Iptr>,
55 ) -> Self::Output {
56 if self.storage() != rhs.view().storage() {
57 return csmat_binop(
58 self.view(),
59 rhs.to_other_storage().view(),
60 |x, y| x.add(y),
61 );
62 }
63 csmat_binop(self.view(), rhs.view(), |x, y| x.add(y))
64 }
65}
66
67impl<
68 'a,
69 'b,
70 Lhs,
71 Rhs,
72 Res,
73 I,
74 Iptr,
75 IpStorage,
76 IStorage,
77 DStorage,
78 IpS2,
79 IS2,
80 DS2,
81 > Sub<&'b CsMatBase<Rhs, I, IpS2, IS2, DS2, Iptr>>
82 for &'a CsMatBase<Lhs, I, IpStorage, IStorage, DStorage, Iptr>
83where
84 Lhs: Zero,
85 Rhs: Zero + Clone + Default,
86 Res: Zero + Clone,
87 for<'r> &'r Lhs: Sub<&'r Rhs, Output = Res>,
88 I: 'a + SpIndex,
89 Iptr: 'a + SpIndex,
90 IpStorage: 'a + Deref<Target = [Iptr]>,
91 IStorage: 'a + Deref<Target = [I]>,
92 DStorage: 'a + Deref<Target = [Lhs]>,
93 IpS2: 'a + Deref<Target = [Iptr]>,
94 IS2: 'a + Deref<Target = [I]>,
95 DS2: 'a + Deref<Target = [Rhs]>,
96{
97 type Output = CsMatI<Res, I, Iptr>;
98
99 fn sub(
100 self,
101 rhs: &'b CsMatBase<Rhs, I, IpS2, IS2, DS2, Iptr>,
102 ) -> Self::Output {
103 if self.storage() != rhs.view().storage() {
104 return csmat_binop(
105 self.view(),
106 rhs.to_other_storage().view(),
107 |x, y| x - y,
108 );
109 }
110 csmat_binop(self.view(), rhs.view(), |x, y| x - y)
111 }
112}
113
114pub fn mul_mat_same_storage<Lhs, Rhs, Res, I, Iptr, Mat1, Mat2>(
116 lhs: &Mat1,
117 rhs: &Mat2,
118) -> CsMatI<Res, I, Iptr>
119where
120 Lhs: Zero,
121 Rhs: Zero,
122 Res: Zero + Clone,
123 for<'r> &'r Lhs: std::ops::Mul<&'r Rhs, Output = Res>,
124 I: SpIndex,
125 Iptr: SpIndex,
126 Mat1: SpMatView<Lhs, I, Iptr>,
127 Mat2: SpMatView<Rhs, I, Iptr>,
128{
129 csmat_binop(lhs.view(), rhs.view(), |x, y| x * y)
130}
131
132macro_rules! sparse_scalar_mul {
133 ($scalar: ident) => {
134 impl<'a, I, Iptr, IpStorage, IStorage, DStorage> Mul<$scalar>
135 for &'a CsMatBase<$scalar, I, IpStorage, IStorage, DStorage, Iptr>
136 where
137 I: 'a + SpIndex,
138 Iptr: 'a + SpIndex,
139 IpStorage: 'a + Deref<Target = [Iptr]>,
140 IStorage: 'a + Deref<Target = [I]>,
141 DStorage: 'a + Deref<Target = [$scalar]>,
142 {
143 type Output = CsMatI<$scalar, I, Iptr>;
144
145 fn mul(self, rhs: $scalar) -> Self::Output {
146 self.map(|x| x * rhs)
147 }
148 }
149 };
150}
151
152sparse_scalar_mul!(u8);
153sparse_scalar_mul!(i8);
154sparse_scalar_mul!(u16);
155sparse_scalar_mul!(i16);
156sparse_scalar_mul!(u32);
157sparse_scalar_mul!(i32);
158sparse_scalar_mul!(u64);
159sparse_scalar_mul!(i64);
160sparse_scalar_mul!(isize);
161sparse_scalar_mul!(usize);
162sparse_scalar_mul!(f32);
163sparse_scalar_mul!(f64);
164
165pub fn csmat_binop<Lhs, Rhs, Res, I, Iptr, F>(
179 lhs: CsMatViewI<Lhs, I, Iptr>,
180 rhs: CsMatViewI<Rhs, I, Iptr>,
181 binop: F,
182) -> CsMatI<Res, I, Iptr>
183where
184 Lhs: Zero,
185 Rhs: Zero,
186 Res: Zero + Clone,
187 I: SpIndex,
188 Iptr: SpIndex,
189 F: Fn(&Lhs, &Rhs) -> Res,
190{
191 let nrows = lhs.rows();
192 let ncols = lhs.cols();
193 let storage = lhs.storage();
194 assert!(
195 nrows == rhs.rows() && ncols == rhs.cols(),
196 "Dimension mismatch"
197 );
198 assert_eq!(storage, rhs.storage(), "Storage mismatch");
199
200 let max_nnz = lhs.nnz() + rhs.nnz();
201 let mut out_indptr = vec![Iptr::zero(); lhs.outer_dims() + 1];
202 let mut out_indices = vec![I::zero(); max_nnz];
203 let mut out_data = vec![Res::zero(); max_nnz];
204
205 let nnz = csmat_binop_same_storage_raw(
206 lhs,
207 rhs,
208 binop,
209 &mut out_indptr[..],
210 &mut out_indices[..],
211 &mut out_data[..],
212 );
213 out_indices.truncate(nnz);
214 out_data.truncate(nnz);
215 CsMatI {
216 storage,
217 nrows,
218 ncols,
219 indptr: IndPtr::new_trusted(out_indptr),
220 indices: out_indices,
221 data: out_data,
222 }
223}
224
225pub fn csmat_binop_same_storage_raw<Lhs, Rhs, Res, I, Iptr, F>(
230 lhs: CsMatViewI<Lhs, I, Iptr>,
231 rhs: CsMatViewI<Rhs, I, Iptr>,
232 binop: F,
233 out_indptr: &mut [Iptr],
234 out_indices: &mut [I],
235 out_data: &mut [Res],
236) -> usize
237where
238 Lhs: Zero,
239 Rhs: Zero,
240 Res: Zero,
241 I: SpIndex,
242 Iptr: SpIndex,
243 F: Fn(&Lhs, &Rhs) -> Res,
244{
245 assert_eq!(lhs.cols(), rhs.cols());
246 assert_eq!(lhs.rows(), rhs.rows());
247 assert_eq!(lhs.storage(), rhs.storage());
248 assert_eq!(out_indptr.len(), rhs.outer_dims() + 1);
249 let max_nnz = lhs.nnz() + rhs.nnz();
250 assert!(out_data.len() >= max_nnz);
251 assert!(out_indices.len() >= max_nnz);
252 let mut nnz = 0;
253 out_indptr[0] = Iptr::zero();
254 let iter = lhs.outer_iterator().zip(rhs.outer_iterator()).enumerate();
255 for (dim, (lv, rv)) in iter {
256 for elem in lv.iter().nnz_or_zip(rv.iter()) {
257 let (ind, binop_val) = match elem {
258 Left((ind, val)) => (ind, binop(val, &Rhs::zero())),
259 Right((ind, val)) => (ind, binop(&Lhs::zero(), val)),
260 Both((ind, lval, rval)) => (ind, binop(lval, rval)),
261 };
262 if !binop_val.is_zero() {
263 out_indices[nnz] = I::from_usize_unchecked(ind);
264 out_data[nnz] = binop_val;
265 nnz += 1;
266 }
267 }
268 out_indptr[dim + 1] = Iptr::from_usize(nnz);
269 }
270 nnz
271}
272
273pub fn add_dense_mat_same_ordering<
280 Lhs,
281 Rhs,
282 Res,
283 Alpha,
284 Beta,
285 ByProd1,
286 ByProd2,
287 I,
288 Iptr,
289 Mat,
290 D,
291>(
292 lhs: &Mat,
293 rhs: &ArrayBase<D, Ix2>,
294 alpha: Alpha,
295 beta: Beta,
296) -> Array<Res, Ix2>
297where
298 Mat: SpMatView<Lhs, I, Iptr>,
299 D: ndarray::Data<Elem = Rhs>,
300 Lhs: Zero,
301 Rhs: Zero,
302 Res: Zero + Copy,
303 for<'r> &'r Alpha: Mul<&'r Lhs, Output = ByProd1>,
304 for<'r> &'r Beta: Mul<&'r Rhs, Output = ByProd2>,
305 ByProd1: Add<ByProd2, Output = Res>,
306 I: SpIndex,
307 Iptr: SpIndex,
308{
309 let shape = (rhs.shape()[0], rhs.shape()[1]);
310 let is_clike_layout = super::utils::fastest_axis(rhs.view()) == Axis(1);
311 let mut res = if is_clike_layout {
312 Array::zeros(shape)
313 } else {
314 Array::zeros(shape.f())
315 };
316 csmat_binop_dense_raw(
317 lhs.view(),
318 rhs.view(),
319 |x, y| &alpha * x + &beta * y,
320 res.view_mut(),
321 );
322 res
323}
324
325pub fn mul_dense_mat_same_ordering<
332 Lhs,
333 Rhs,
334 Res,
335 Alpha,
336 ByProd,
337 I,
338 Iptr,
339 Mat,
340 D,
341>(
342 lhs: &Mat,
343 rhs: &ArrayBase<D, Ix2>,
344 alpha: Alpha,
345) -> Array<Res, Ix2>
346where
347 Lhs: Zero,
348 Rhs: Zero,
349 Res: Zero + Clone,
350 Alpha: Copy + for<'r> Mul<&'r Lhs, Output = ByProd>,
351 ByProd: for<'r> Mul<&'r Rhs, Output = Res>,
352 I: SpIndex,
353 Iptr: SpIndex,
354 Mat: SpMatView<Lhs, I, Iptr>,
355 D: ndarray::Data<Elem = Rhs>,
356{
357 let shape = (rhs.shape()[0], rhs.shape()[1]);
358 let is_clike_layout = super::utils::fastest_axis(rhs.view()) == Axis(1);
359 let mut res = if is_clike_layout {
360 Array::zeros(shape)
361 } else {
362 Array::zeros(shape.f())
363 };
364 csmat_binop_dense_raw(
365 lhs.view(),
366 rhs.view(),
367 |x, y| alpha * x * y,
368 res.view_mut(),
369 );
370 res
371}
372
373pub fn csmat_binop_dense_raw<'a, Lhs, Rhs, Res, I, Iptr, F>(
385 lhs: CsMatViewI<'a, Lhs, I, Iptr>,
386 rhs: ArrayView<'a, Rhs, Ix2>,
387 binop: F,
388 mut out: ArrayViewMut<'a, Res, Ix2>,
389) where
390 Lhs: 'a + Zero,
391 Rhs: 'a + Zero,
392 Res: Zero,
393 I: 'a + SpIndex,
394 Iptr: 'a + SpIndex,
395 F: Fn(&Lhs, &Rhs) -> Res,
396{
397 if lhs.cols() != rhs.shape()[1]
398 || lhs.cols() != out.shape()[1]
399 || lhs.rows() != rhs.shape()[0]
400 || lhs.rows() != out.shape()[0]
401 {
402 panic!("Dimension mismatch");
403 }
404 match (
405 lhs.storage(),
406 super::utils::fastest_axis(rhs),
407 super::utils::fastest_axis(out.view()),
408 ) {
409 (CompressedStorage::CSR, Axis(1), Axis(1))
410 | (CompressedStorage::CSC, Axis(0), Axis(0)) => (),
411 (_, _, _) => panic!("Storage mismatch"),
412 }
413 let slowest_axis = super::utils::slowest_axis(rhs);
414 for ((mut orow, lrow), rrow) in out
415 .axis_iter_mut(slowest_axis)
416 .zip(lhs.outer_iterator())
417 .zip(rhs.axis_iter(slowest_axis))
418 {
419 for items in orow
421 .iter_mut()
422 .zip(rrow.iter().enumerate().nnz_or_zip(lrow.iter()))
423 {
424 let (oval, rl_elems) = items;
425 let binop_val = match rl_elems {
426 Left((_, val)) => binop(&Lhs::zero(), val),
427 Right((_, val)) => binop(val, &Rhs::zero()),
428 Both((_, rval, lval)) => binop(lval, rval),
429 };
430 *oval = binop_val;
431 }
432 }
433}
434
435pub fn csvec_binop<Lhs, Rhs, Res, I, F>(
443 mut lhs: CsVecViewI<Lhs, I>,
444 mut rhs: CsVecViewI<Rhs, I>,
445 binop: F,
446) -> Result<CsVecI<Res, I>, StructureError>
447where
448 Lhs: Zero,
449 Rhs: Zero,
450 F: Fn(&Lhs, &Rhs) -> Res,
451 I: SpIndex,
452{
453 csvec_fix_zeros(&mut lhs, &mut rhs);
454 assert_eq!(lhs.dim(), rhs.dim(), "Dimension mismatch");
455 let mut res = CsVecI::empty(lhs.dim());
456 let max_nnz = lhs.nnz() + rhs.nnz();
457 res.reserve_exact(max_nnz);
458 for elem in lhs.iter().nnz_or_zip(rhs.iter()) {
459 let (ind, binop_val) = match elem {
460 Left((ind, val)) => (ind, binop(val, &Rhs::zero())),
461 Right((ind, val)) => (ind, binop(&Lhs::zero(), val)),
462 Both((ind, lval, rval)) => (ind, binop(lval, rval)),
463 };
464 res.append(ind, binop_val);
465 }
466 Ok(res)
467}
468
469fn csvec_fix_zeros<Lhs, Rhs, I: SpIndex>(
470 lhs: &mut CsVecViewI<Lhs, I>,
471 rhs: &mut CsVecViewI<Rhs, I>,
472) {
473 if rhs.dim() == 0 {
474 rhs.dim = lhs.dim;
475 }
476 if lhs.dim() == 0 {
477 lhs.dim = rhs.dim;
478 }
479}
480
// Unit tests for the sparse/sparse, sparse/scalar and sparse/dense binary
// operations defined in this file. Fixture matrices (`mat1`, `mat2`,
// `mat1_times_2`, `mat_dense1`) come from `crate::test_data`.
#[cfg(test)]
mod test {
    use crate::sparse::CsMat;
    use crate::sparse::CsVec;
    use crate::test_data::{mat1, mat1_times_2, mat2, mat_dense1};
    use ndarray::{arr2, Array};

    /// Expected 5x5 CSR result of `mat1() + mat2()`.
    fn mat1_plus_mat2() -> CsMat<f64> {
        let indptr = vec![0, 5, 8, 9, 12, 15];
        let indices = vec![0, 1, 2, 3, 4, 0, 3, 4, 2, 1, 2, 3, 1, 2, 3];
        let data =
            vec![6., 7., 6., 4., 3., 8., 11., 5., 5., 8., 2., 4., 4., 4., 7.];
        CsMat::new((5, 5), indptr, indices, data)
    }

    /// Expected 5x5 CSR result of `mat1() - mat2()`; it has one entry fewer
    /// than the sum because a difference cancels to zero and is not stored.
    fn mat1_minus_mat2() -> CsMat<f64> {
        let indptr = vec![0, 4, 7, 8, 11, 14];
        let indices = vec![0, 1, 3, 4, 0, 3, 4, 2, 1, 2, 3, 1, 2, 3];
        let data = vec![
            -6., -7., 4., -3., -8., -7., 5., 5., 8., -2., -4., -4., -4., 7.,
        ];
        CsMat::new((5, 5), indptr, indices, data)
    }

    /// Expected element-wise (not matrix) product of `mat1()` and `mat2()`.
    fn mat1_times_mat2() -> CsMat<f64> {
        let indptr = vec![0, 1, 2, 2, 2, 2];
        let indices = vec![2, 3];
        let data = vec![9., 18.];
        CsMat::new((5, 5), indptr, indices, data)
    }

    #[test]
    fn test_add1() {
        let a = mat1();
        let b = mat2();

        let c = &a + &b;
        let c_true = mat1_plus_mat2();
        assert_eq!(c, c_true);

        // Overlapping (0,0) entries sum to 2; (1,1) and (2,2) come from a
        // single operand each.
        let a = CsMat::new((3, 3), vec![0, 1, 1, 2], vec![0, 2], vec![1., 1.]);
        let b = CsMat::new((3, 3), vec![0, 1, 2, 2], vec![0, 1], vec![1., 1.]);
        let c = CsMat::new(
            (3, 3),
            vec![0, 1, 2, 3],
            vec![0, 1, 2],
            vec![2., 1., 1.],
        );

        assert_eq!(c, &a + &b);
    }

    #[test]
    fn test_sub1() {
        let a = mat1();
        let b = mat2();

        let c = &a - &b;
        let c_true = mat1_minus_mat2();
        assert_eq!(c, c_true);
    }

    #[test]
    fn test_mul1() {
        let a = mat1();
        let b = mat2();

        let c = super::mul_mat_same_storage(&a, &b);
        let c_true = mat1_times_mat2();
        // Compare the raw CSR arrays so that explicit zeros would be caught.
        assert_eq!(c.indptr(), c_true.indptr());
        assert_eq!(c.indices(), c_true.indices());
        assert_eq!(c.data(), c_true.data());
    }

    #[test]
    fn test_smul() {
        let a = mat1();
        let c = &a * 2.;
        let c_true = mat1_times_2();
        assert_eq!(c.indptr(), c_true.indptr());
        assert_eq!(c.indices(), c_true.indices());
        assert_eq!(c.data(), c_true.data());
    }

    #[test]
    fn csvec_binops() {
        let vec1 = CsVec::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
        let vec2 = CsVec::new(8, vec![1, 3, 5, 7], vec![2.; 4]);
        let vec3 = CsVec::new(8, vec![1, 2, 5, 6], vec![3.; 4]);

        // Disjoint patterns: the result is the interleaving of both inputs.
        let res = &vec1 + &vec2;
        let expected_output = CsVec::new(
            8,
            vec![0, 1, 2, 3, 4, 5, 6, 7],
            vec![1., 2., 1., 2., 1., 2., 1., 2.],
        );
        assert_eq!(expected_output, res);

        // Partially overlapping patterns: shared indices 2 and 6 sum to 4.
        let res = &vec1 + &vec3;
        let expected_output =
            CsVec::new(8, vec![0, 1, 2, 4, 5, 6], vec![1., 3., 4., 1., 3., 4.]);
        assert_eq!(expected_output, res);
    }

    // A zero-dimensional vector is treated as an empty vector of matching
    // dimension (see `csvec_fix_zeros`).
    #[test]
    fn zero_sized_vector_works_as_right_vector_operand() {
        let vector = CsVec::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
        let zero = CsVec::<f64>::new(0, vec![], vec![]);
        assert_eq!(&vector + zero, vector);
    }

    #[test]
    fn zero_sized_vector_works_as_left_vector_operand() {
        let vector = CsVec::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
        let zero = CsVec::<f64>::new(0, vec![], vec![]);
        assert_eq!(zero + &vector, vector);
    }

    #[test]
    fn csr_add_dense_rowmaj() {
        let a = Array::<f32, ndarray::Dim<[usize; 2]>>::zeros((3, 3));
        let b = CsMat::<f32>::eye(3);

        let c = super::add_dense_mat_same_ordering(&b, &a, 1., 1.);

        let mut expected_output = Array::zeros((3, 3));
        expected_output[[0, 0]] = 1.;
        expected_output[[1, 1]] = 1.;
        expected_output[[2, 2]] = 1.;

        assert_eq!(c, expected_output);

        let a = mat1();
        let b = mat_dense1();

        let expected_output = arr2(&[
            [0., 1., 5., 7., 4.],
            [5., 6., 5., 6., 8.],
            [4., 5., 9., 3., 2.],
            [3., 12., 3., 2., 1.],
            [1., 2., 1., 8., 0.],
        ]);
        let c = super::add_dense_mat_same_ordering(&a, &b, 1., 1.);
        assert_eq!(c, expected_output);
        // The `+` operator between a sparse matrix and a dense array must
        // agree with the explicit function.
        let c = &a + &b;
        assert_eq!(c, expected_output);
    }

    #[test]
    fn csr_mul_dense_rowmaj() {
        let a = Array::from_elem((3, 3), 1.);
        let b = CsMat::<f64>::eye(3);

        let c = super::mul_dense_mat_same_ordering(&b, &a, 1.);

        let expected_output = Array::eye(3);

        assert_eq!(c, expected_output);
    }

    // Non-contiguous (stepped) dense inputs, in both C and Fortran order; the
    // output layout must follow the input's order.
    #[test]
    fn mul_dense_strided() {
        let a = Array::from_elem((6, 6), 1.0);
        let a = a.slice(ndarray::s![..;2, ..;2]);
        let b = CsMat::<f64>::eye(3);

        let c = super::mul_dense_mat_same_ordering(&b, &a, 1.0);
        assert!(c.is_standard_layout());

        let expected_output = Array::eye(3);
        assert_eq!(c, expected_output);

        use ndarray::ShapeBuilder;
        let a = Array::from_elem((6, 6).f(), 1.0);
        let a = a.slice(ndarray::s![..;2, ..;2]);
        let b = CsMat::<f64>::eye_csc(3);

        let c = super::mul_dense_mat_same_ordering(&b, &a, 1.0);
        // Fortran-ordered output: its transpose is in standard layout.
        assert!(c.t().is_standard_layout());

        let expected_output = Array::eye(3);
        assert_eq!(c, expected_output);
    }

    // Only checks that the layout validation in `csmat_binop_dense_raw`
    // accepts the matching CSR/C-order and CSC/F-order combinations.
    #[test]
    fn binop_standard_layouts() {
        use ndarray::ShapeBuilder;
        let csr = CsMat::zero((3, 4));
        let a = Array::from_elem((3, 4), 1.0);
        let mut out = a.clone();
        super::csmat_binop_dense_raw(
            csr.view(),
            a.view(),
            |_: &f32, _: &f32| 0.0,
            out.view_mut(),
        );

        let csc = CsMat::zero((3, 4)).into_csc();
        let a = Array::from_elem((3, 4).f(), 1.0);
        let mut out = Array::zeros((3, 4).f());
        super::csmat_binop_dense_raw(
            csc.view(),
            a.view(),
            |_: &f32, _: &f32| 0.0,
            out.view_mut(),
        );
    }

    // Same as above, but with column-stepped (strided) dense inputs.
    #[test]
    fn binop_strided_layouts() {
        use ndarray::{s, ShapeBuilder};
        let csr = CsMat::zero((3, 4));
        let a = Array::from_elem((3, 8), 1.0);
        let a = a.slice(s![.., ..;2]);
        let mut out = Array::zeros((3, 4));
        super::csmat_binop_dense_raw(
            csr.view(),
            a.view(),
            |_: &f32, _: &f32| 0.0,
            out.view_mut(),
        );

        let csc = CsMat::zero((3, 4)).into_csc();
        let a = Array::from_elem((3, 8).f(), 1.0);
        let a = a.slice(s![.., ..;2]);
        let mut out = Array::zeros((3, 4).f());
        super::csmat_binop_dense_raw(
            csc.view(),
            a.view(),
            |_: &f32, _: &f32| 0.0,
            out.view_mut(),
        );
    }
}
719}