sprs_rssn/sparse/
binop.rs

//! Sparse matrix and vector binary operations: addition, subtraction, and element-wise and scalar multiplication
2
3use std::ops::{Add, Deref, Mul, Sub};
4
5use crate::errors::StructureError;
6use crate::indexing::SpIndex;
7use crate::sparse::compressed::SpMatView;
8use crate::sparse::csmat::CompressedStorage;
9use crate::sparse::prelude::*;
10use crate::sparse::vec::NnzEither::{Both, Left, Right};
11use crate::sparse::vec::SparseIterTools;
12use crate::IndPtr;
13use ndarray::{
14    self, Array, ArrayBase, ArrayView, ArrayViewMut, Axis, ShapeBuilder,
15};
16use num_traits::Zero;
17
18use crate::Ix2;
19
20impl<
21        'a,
22        'b,
23        Lhs,
24        Rhs,
25        Res,
26        I,
27        Iptr,
28        IpStorage,
29        IStorage,
30        DStorage,
31        IpS2,
32        IS2,
33        DS2,
34    > Add<&'b CsMatBase<Rhs, I, IpS2, IS2, DS2, Iptr>>
35    for &'a CsMatBase<Lhs, I, IpStorage, IStorage, DStorage, Iptr>
36where
37    Lhs: Zero,
38    Rhs: Zero + Clone + Default,
39    Res: Zero + Clone,
40    for<'r> &'r Lhs: Add<&'r Rhs, Output = Res>,
41    I: 'a + SpIndex,
42    Iptr: 'a + SpIndex,
43    IpStorage: 'a + Deref<Target = [Iptr]>,
44    IStorage: 'a + Deref<Target = [I]>,
45    DStorage: 'a + Deref<Target = [Lhs]>,
46    IpS2: 'b + Deref<Target = [Iptr]>,
47    IS2: 'b + Deref<Target = [I]>,
48    DS2: 'b + Deref<Target = [Rhs]>,
49{
50    type Output = CsMatI<Res, I, Iptr>;
51
52    fn add(
53        self,
54        rhs: &'b CsMatBase<Rhs, I, IpS2, IS2, DS2, Iptr>,
55    ) -> Self::Output {
56        if self.storage() != rhs.view().storage() {
57            return csmat_binop(
58                self.view(),
59                rhs.to_other_storage().view(),
60                |x, y| x.add(y),
61            );
62        }
63        csmat_binop(self.view(), rhs.view(), |x, y| x.add(y))
64    }
65}
66
67impl<
68        'a,
69        'b,
70        Lhs,
71        Rhs,
72        Res,
73        I,
74        Iptr,
75        IpStorage,
76        IStorage,
77        DStorage,
78        IpS2,
79        IS2,
80        DS2,
81    > Sub<&'b CsMatBase<Rhs, I, IpS2, IS2, DS2, Iptr>>
82    for &'a CsMatBase<Lhs, I, IpStorage, IStorage, DStorage, Iptr>
83where
84    Lhs: Zero,
85    Rhs: Zero + Clone + Default,
86    Res: Zero + Clone,
87    for<'r> &'r Lhs: Sub<&'r Rhs, Output = Res>,
88    I: 'a + SpIndex,
89    Iptr: 'a + SpIndex,
90    IpStorage: 'a + Deref<Target = [Iptr]>,
91    IStorage: 'a + Deref<Target = [I]>,
92    DStorage: 'a + Deref<Target = [Lhs]>,
93    IpS2: 'a + Deref<Target = [Iptr]>,
94    IS2: 'a + Deref<Target = [I]>,
95    DS2: 'a + Deref<Target = [Rhs]>,
96{
97    type Output = CsMatI<Res, I, Iptr>;
98
99    fn sub(
100        self,
101        rhs: &'b CsMatBase<Rhs, I, IpS2, IS2, DS2, Iptr>,
102    ) -> Self::Output {
103        if self.storage() != rhs.view().storage() {
104            return csmat_binop(
105                self.view(),
106                rhs.to_other_storage().view(),
107                |x, y| x - y,
108            );
109        }
110        csmat_binop(self.view(), rhs.view(), |x, y| x - y)
111    }
112}
113
114/// Sparse matrix scalar multiplication, with same storage type
115pub fn mul_mat_same_storage<Lhs, Rhs, Res, I, Iptr, Mat1, Mat2>(
116    lhs: &Mat1,
117    rhs: &Mat2,
118) -> CsMatI<Res, I, Iptr>
119where
120    Lhs: Zero,
121    Rhs: Zero,
122    Res: Zero + Clone,
123    for<'r> &'r Lhs: std::ops::Mul<&'r Rhs, Output = Res>,
124    I: SpIndex,
125    Iptr: SpIndex,
126    Mat1: SpMatView<Lhs, I, Iptr>,
127    Mat2: SpMatView<Rhs, I, Iptr>,
128{
129    csmat_binop(lhs.view(), rhs.view(), |x, y| x * y)
130}
131
/// Implements `&CsMatBase<T, ..> * T` for a primitive numeric type `T`.
///
/// Every value is mapped through `value * factor` via `CsMatBase::map`,
/// producing an owned matrix with the same element type.
macro_rules! sparse_scalar_mul {
    ($scalar: ident) => {
        impl<'a, I, Iptr, IpStorage, IStorage, DStorage> Mul<$scalar>
            for &'a CsMatBase<$scalar, I, IpStorage, IStorage, DStorage, Iptr>
        where
            I: 'a + SpIndex,
            Iptr: 'a + SpIndex,
            IpStorage: 'a + Deref<Target = [Iptr]>,
            IStorage: 'a + Deref<Target = [I]>,
            DStorage: 'a + Deref<Target = [$scalar]>,
        {
            type Output = CsMatI<$scalar, I, Iptr>;

            fn mul(self, factor: $scalar) -> Self::Output {
                self.map(|value| value * factor)
            }
        }
    };
}
151
152sparse_scalar_mul!(u8);
153sparse_scalar_mul!(i8);
154sparse_scalar_mul!(u16);
155sparse_scalar_mul!(i16);
156sparse_scalar_mul!(u32);
157sparse_scalar_mul!(i32);
158sparse_scalar_mul!(u64);
159sparse_scalar_mul!(i64);
160sparse_scalar_mul!(isize);
161sparse_scalar_mul!(usize);
162sparse_scalar_mul!(f32);
163sparse_scalar_mul!(f64);
164
165/// Apply binary operation to two sparse matrices
166///
167/// Applies a binary operation to matching non-zero elements
168/// of two sparse matrices. When e.g. only the `lhs` has a non-zero at a
169/// given location, `0` is inferred for the non-zero value of the other matrix.
170/// Both matrices should have the same storage.
171///
172/// Thus the behaviour is correct iff `binop(N::zero(), N::zero()) == N::zero()`
173///
174/// # Panics
175///
176/// - on incompatible dimensions
177/// - on incomatible storage
178pub fn csmat_binop<Lhs, Rhs, Res, I, Iptr, F>(
179    lhs: CsMatViewI<Lhs, I, Iptr>,
180    rhs: CsMatViewI<Rhs, I, Iptr>,
181    binop: F,
182) -> CsMatI<Res, I, Iptr>
183where
184    Lhs: Zero,
185    Rhs: Zero,
186    Res: Zero + Clone,
187    I: SpIndex,
188    Iptr: SpIndex,
189    F: Fn(&Lhs, &Rhs) -> Res,
190{
191    let nrows = lhs.rows();
192    let ncols = lhs.cols();
193    let storage = lhs.storage();
194    assert!(
195        nrows == rhs.rows() && ncols == rhs.cols(),
196        "Dimension mismatch"
197    );
198    assert_eq!(storage, rhs.storage(), "Storage mismatch");
199
200    let max_nnz = lhs.nnz() + rhs.nnz();
201    let mut out_indptr = vec![Iptr::zero(); lhs.outer_dims() + 1];
202    let mut out_indices = vec![I::zero(); max_nnz];
203    let mut out_data = vec![Res::zero(); max_nnz];
204
205    let nnz = csmat_binop_same_storage_raw(
206        lhs,
207        rhs,
208        binop,
209        &mut out_indptr[..],
210        &mut out_indices[..],
211        &mut out_data[..],
212    );
213    out_indices.truncate(nnz);
214    out_data.truncate(nnz);
215    CsMatI {
216        storage,
217        nrows,
218        ncols,
219        indptr: IndPtr::new_trusted(out_indptr),
220        indices: out_indices,
221        data: out_data,
222    }
223}
224
225/// Raw implementation of scalar binary operation for compressed sparse matrices
226/// sharing the same storage. The output arrays are assumed to be preallocated
227///
228/// Returns the nnz count
229pub fn csmat_binop_same_storage_raw<Lhs, Rhs, Res, I, Iptr, F>(
230    lhs: CsMatViewI<Lhs, I, Iptr>,
231    rhs: CsMatViewI<Rhs, I, Iptr>,
232    binop: F,
233    out_indptr: &mut [Iptr],
234    out_indices: &mut [I],
235    out_data: &mut [Res],
236) -> usize
237where
238    Lhs: Zero,
239    Rhs: Zero,
240    Res: Zero,
241    I: SpIndex,
242    Iptr: SpIndex,
243    F: Fn(&Lhs, &Rhs) -> Res,
244{
245    assert_eq!(lhs.cols(), rhs.cols());
246    assert_eq!(lhs.rows(), rhs.rows());
247    assert_eq!(lhs.storage(), rhs.storage());
248    assert_eq!(out_indptr.len(), rhs.outer_dims() + 1);
249    let max_nnz = lhs.nnz() + rhs.nnz();
250    assert!(out_data.len() >= max_nnz);
251    assert!(out_indices.len() >= max_nnz);
252    let mut nnz = 0;
253    out_indptr[0] = Iptr::zero();
254    let iter = lhs.outer_iterator().zip(rhs.outer_iterator()).enumerate();
255    for (dim, (lv, rv)) in iter {
256        for elem in lv.iter().nnz_or_zip(rv.iter()) {
257            let (ind, binop_val) = match elem {
258                Left((ind, val)) => (ind, binop(val, &Rhs::zero())),
259                Right((ind, val)) => (ind, binop(&Lhs::zero(), val)),
260                Both((ind, lval, rval)) => (ind, binop(lval, rval)),
261            };
262            if !binop_val.is_zero() {
263                out_indices[nnz] = I::from_usize_unchecked(ind);
264                out_data[nnz] = binop_val;
265                nnz += 1;
266            }
267        }
268        out_indptr[dim + 1] = Iptr::from_usize(nnz);
269    }
270    nnz
271}
272
273/// Compute alpha * lhs + beta * rhs with lhs a sparse matrix and rhs dense
274/// and alpha and beta scalars
275///
276/// The matrices must have the same ordering, a `CSR` matrix must be
277/// added with a matrix with `C`-like ordering, a `CSC` matrix
278/// must be added with a matrix with `F`-like ordering.
279pub fn add_dense_mat_same_ordering<
280    Lhs,
281    Rhs,
282    Res,
283    Alpha,
284    Beta,
285    ByProd1,
286    ByProd2,
287    I,
288    Iptr,
289    Mat,
290    D,
291>(
292    lhs: &Mat,
293    rhs: &ArrayBase<D, Ix2>,
294    alpha: Alpha,
295    beta: Beta,
296) -> Array<Res, Ix2>
297where
298    Mat: SpMatView<Lhs, I, Iptr>,
299    D: ndarray::Data<Elem = Rhs>,
300    Lhs: Zero,
301    Rhs: Zero,
302    Res: Zero + Copy,
303    for<'r> &'r Alpha: Mul<&'r Lhs, Output = ByProd1>,
304    for<'r> &'r Beta: Mul<&'r Rhs, Output = ByProd2>,
305    ByProd1: Add<ByProd2, Output = Res>,
306    I: SpIndex,
307    Iptr: SpIndex,
308{
309    let shape = (rhs.shape()[0], rhs.shape()[1]);
310    let is_clike_layout = super::utils::fastest_axis(rhs.view()) == Axis(1);
311    let mut res = if is_clike_layout {
312        Array::zeros(shape)
313    } else {
314        Array::zeros(shape.f())
315    };
316    csmat_binop_dense_raw(
317        lhs.view(),
318        rhs.view(),
319        |x, y| &alpha * x + &beta * y,
320        res.view_mut(),
321    );
322    res
323}
324
325/// Compute coeff wise `alpha * lhs * rhs` with `lhs` a sparse matrix,
326/// `rhs` a dense matrix, and `alpha` a scalar
327///
328/// The matrices must have the same ordering, a `CSR` matrix must be
329/// multiplied with a matrix with `C`-like ordering, a `CSC` matrix
330/// must be multiplied with a matrix with `F`-like ordering.
331pub fn mul_dense_mat_same_ordering<
332    Lhs,
333    Rhs,
334    Res,
335    Alpha,
336    ByProd,
337    I,
338    Iptr,
339    Mat,
340    D,
341>(
342    lhs: &Mat,
343    rhs: &ArrayBase<D, Ix2>,
344    alpha: Alpha,
345) -> Array<Res, Ix2>
346where
347    Lhs: Zero,
348    Rhs: Zero,
349    Res: Zero + Clone,
350    Alpha: Copy + for<'r> Mul<&'r Lhs, Output = ByProd>,
351    ByProd: for<'r> Mul<&'r Rhs, Output = Res>,
352    I: SpIndex,
353    Iptr: SpIndex,
354    Mat: SpMatView<Lhs, I, Iptr>,
355    D: ndarray::Data<Elem = Rhs>,
356{
357    let shape = (rhs.shape()[0], rhs.shape()[1]);
358    let is_clike_layout = super::utils::fastest_axis(rhs.view()) == Axis(1);
359    let mut res = if is_clike_layout {
360        Array::zeros(shape)
361    } else {
362        Array::zeros(shape.f())
363    };
364    csmat_binop_dense_raw(
365        lhs.view(),
366        rhs.view(),
367        |x, y| alpha * x * y,
368        res.view_mut(),
369    );
370    res
371}
372
373/// Raw implementation of sparse/dense binary operations with the same
374/// ordering
375///
376/// # Panics
377///
378/// On dimension mismatch
379///
380/// On storage mismatch. The storage for the matrices must either be
381/// `lhs = CSR` with `rhs` and `out` with `Axis(1)` as the fastest dimension,
382/// or
383/// `lhs = CSC` with `rhs` and `out` with `Axis(0)` as the fastest dimension,
384pub fn csmat_binop_dense_raw<'a, Lhs, Rhs, Res, I, Iptr, F>(
385    lhs: CsMatViewI<'a, Lhs, I, Iptr>,
386    rhs: ArrayView<'a, Rhs, Ix2>,
387    binop: F,
388    mut out: ArrayViewMut<'a, Res, Ix2>,
389) where
390    Lhs: 'a + Zero,
391    Rhs: 'a + Zero,
392    Res: Zero,
393    I: 'a + SpIndex,
394    Iptr: 'a + SpIndex,
395    F: Fn(&Lhs, &Rhs) -> Res,
396{
397    if lhs.cols() != rhs.shape()[1]
398        || lhs.cols() != out.shape()[1]
399        || lhs.rows() != rhs.shape()[0]
400        || lhs.rows() != out.shape()[0]
401    {
402        panic!("Dimension mismatch");
403    }
404    match (
405        lhs.storage(),
406        super::utils::fastest_axis(rhs),
407        super::utils::fastest_axis(out.view()),
408    ) {
409        (CompressedStorage::CSR, Axis(1), Axis(1))
410        | (CompressedStorage::CSC, Axis(0), Axis(0)) => (),
411        (_, _, _) => panic!("Storage mismatch"),
412    }
413    let slowest_axis = super::utils::slowest_axis(rhs);
414    for ((mut orow, lrow), rrow) in out
415        .axis_iter_mut(slowest_axis)
416        .zip(lhs.outer_iterator())
417        .zip(rhs.axis_iter(slowest_axis))
418    {
419        // now some equivalent of nnz_or_zip is needed
420        for items in orow
421            .iter_mut()
422            .zip(rrow.iter().enumerate().nnz_or_zip(lrow.iter()))
423        {
424            let (oval, rl_elems) = items;
425            let binop_val = match rl_elems {
426                Left((_, val)) => binop(&Lhs::zero(), val),
427                Right((_, val)) => binop(val, &Rhs::zero()),
428                Both((_, rval, lval)) => binop(lval, rval),
429            };
430            *oval = binop_val;
431        }
432    }
433}
434
435/// Binary operations for [`CsVec`](CsVecBase)
436///
437/// This function iterates the non-zero locations of `lhs` and `rhs`
438/// and applies the function `binop` to the matching elements (defaulting
439/// to zero when e.g. only `lhs` has a non-zero at a given location).
440///
441/// The function thus has a correct behavior iff `binop(0, 0) == 0`.
442pub fn csvec_binop<Lhs, Rhs, Res, I, F>(
443    mut lhs: CsVecViewI<Lhs, I>,
444    mut rhs: CsVecViewI<Rhs, I>,
445    binop: F,
446) -> Result<CsVecI<Res, I>, StructureError>
447where
448    Lhs: Zero,
449    Rhs: Zero,
450    F: Fn(&Lhs, &Rhs) -> Res,
451    I: SpIndex,
452{
453    csvec_fix_zeros(&mut lhs, &mut rhs);
454    assert_eq!(lhs.dim(), rhs.dim(), "Dimension mismatch");
455    let mut res = CsVecI::empty(lhs.dim());
456    let max_nnz = lhs.nnz() + rhs.nnz();
457    res.reserve_exact(max_nnz);
458    for elem in lhs.iter().nnz_or_zip(rhs.iter()) {
459        let (ind, binop_val) = match elem {
460            Left((ind, val)) => (ind, binop(val, &Rhs::zero())),
461            Right((ind, val)) => (ind, binop(&Lhs::zero(), val)),
462            Both((ind, lval, rval)) => (ind, binop(lval, rval)),
463        };
464        res.append(ind, binop_val);
465    }
466    Ok(res)
467}
468
469fn csvec_fix_zeros<Lhs, Rhs, I: SpIndex>(
470    lhs: &mut CsVecViewI<Lhs, I>,
471    rhs: &mut CsVecViewI<Rhs, I>,
472) {
473    if rhs.dim() == 0 {
474        rhs.dim = lhs.dim;
475    }
476    if lhs.dim() == 0 {
477        lhs.dim = rhs.dim;
478    }
479}
480
#[cfg(test)]
mod test {
    use crate::sparse::CsMat;
    use crate::sparse::CsVec;
    use crate::test_data::{mat1, mat1_times_2, mat2, mat_dense1};
    use ndarray::{arr2, Array};

    fn mat1_plus_mat2() -> CsMat<f64> {
        let indptr = vec![0, 5, 8, 9, 12, 15];
        let indices = vec![0, 1, 2, 3, 4, 0, 3, 4, 2, 1, 2, 3, 1, 2, 3];
        let data =
            vec![6., 7., 6., 4., 3., 8., 11., 5., 5., 8., 2., 4., 4., 4., 7.];
        CsMat::new((5, 5), indptr, indices, data)
    }

    fn mat1_minus_mat2() -> CsMat<f64> {
        let indptr = vec![0, 4, 7, 8, 11, 14];
        let indices = vec![0, 1, 3, 4, 0, 3, 4, 2, 1, 2, 3, 1, 2, 3];
        let data = vec![
            -6., -7., 4., -3., -8., -7., 5., 5., 8., -2., -4., -4., -4., 7.,
        ];
        CsMat::new((5, 5), indptr, indices, data)
    }

    fn mat1_times_mat2() -> CsMat<f64> {
        let indptr = vec![0, 1, 2, 2, 2, 2];
        let indices = vec![2, 3];
        let data = vec![9., 18.];
        CsMat::new((5, 5), indptr, indices, data)
    }

    #[test]
    fn test_add1() {
        let a = mat1();
        let b = mat2();

        let c = &a + &b;
        let c_true = mat1_plus_mat2();
        assert_eq!(c, c_true);

        // test with CSR matrices having different row patterns
        let a = CsMat::new((3, 3), vec![0, 1, 1, 2], vec![0, 2], vec![1., 1.]);
        let b = CsMat::new((3, 3), vec![0, 1, 2, 2], vec![0, 1], vec![1., 1.]);
        let c = CsMat::new(
            (3, 3),
            vec![0, 1, 2, 3],
            vec![0, 1, 2],
            vec![2., 1., 1.],
        );

        assert_eq!(c, &a + &b);
    }

    #[test]
    fn test_sub1() {
        let a = mat1();
        let b = mat2();

        let c = &a - &b;
        let c_true = mat1_minus_mat2();
        assert_eq!(c, c_true);
    }

    #[test]
    fn test_mul1() {
        let a = mat1();
        let b = mat2();

        let c = super::mul_mat_same_storage(&a, &b);
        let c_true = mat1_times_mat2();
        assert_eq!(c.indptr(), c_true.indptr());
        assert_eq!(c.indices(), c_true.indices());
        assert_eq!(c.data(), c_true.data());
    }

    #[test]
    fn test_smul() {
        let a = mat1();
        let c = &a * 2.;
        let c_true = mat1_times_2();
        assert_eq!(c.indptr(), c_true.indptr());
        assert_eq!(c.indices(), c_true.indices());
        assert_eq!(c.data(), c_true.data());
    }

    #[test]
    fn csvec_binops() {
        let vec1 = CsVec::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
        let vec2 = CsVec::new(8, vec![1, 3, 5, 7], vec![2.; 4]);
        let vec3 = CsVec::new(8, vec![1, 2, 5, 6], vec![3.; 4]);

        let res = &vec1 + &vec2;
        let expected_output = CsVec::new(
            8,
            vec![0, 1, 2, 3, 4, 5, 6, 7],
            vec![1., 2., 1., 2., 1., 2., 1., 2.],
        );
        assert_eq!(expected_output, res);

        let res = &vec1 + &vec3;
        let expected_output =
            CsVec::new(8, vec![0, 1, 2, 4, 5, 6], vec![1., 3., 4., 1., 3., 4.]);
        assert_eq!(expected_output, res);
    }

    #[test]
    fn zero_sized_vector_works_as_right_vector_operand() {
        let vector = CsVec::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
        let zero = CsVec::<f64>::new(0, vec![], vec![]);
        assert_eq!(&vector + zero, vector);
    }

    #[test]
    fn zero_sized_vector_works_as_left_vector_operand() {
        let vector = CsVec::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
        let zero = CsVec::<f64>::new(0, vec![], vec![]);
        assert_eq!(zero + &vector, vector);
    }

    #[test]
    fn csr_add_dense_rowmaj() {
        let a = Array::<f32, ndarray::Dim<[usize; 2]>>::zeros((3, 3));
        let b = CsMat::<f32>::eye(3);

        let c = super::add_dense_mat_same_ordering(&b, &a, 1., 1.);

        let mut expected_output = Array::zeros((3, 3));
        expected_output[[0, 0]] = 1.;
        expected_output[[1, 1]] = 1.;
        expected_output[[2, 2]] = 1.;

        assert_eq!(c, expected_output);

        let a = mat1();
        let b = mat_dense1();

        let expected_output = arr2(&[
            [0., 1., 5., 7., 4.],
            [5., 6., 5., 6., 8.],
            [4., 5., 9., 3., 2.],
            [3., 12., 3., 2., 1.],
            [1., 2., 1., 8., 0.],
        ]);
        let c = super::add_dense_mat_same_ordering(&a, &b, 1., 1.);
        assert_eq!(c, expected_output);
        let c = &a + &b;
        assert_eq!(c, expected_output);
    }

    #[test]
    fn csr_mul_dense_rowmaj() {
        let a = Array::from_elem((3, 3), 1.);
        let b = CsMat::<f64>::eye(3);

        let c = super::mul_dense_mat_same_ordering(&b, &a, 1.);

        let expected_output = Array::eye(3);

        assert_eq!(c, expected_output);
    }

    #[test]
    fn mul_dense_strided() {
        // Multiplication should yield dense matrices
        // with the same fastest axis as input
        let a = Array::from_elem((6, 6), 1.0);
        let a = a.slice(ndarray::s![..;2, ..;2]);
        let b = CsMat::<f64>::eye(3);

        let c = super::mul_dense_mat_same_ordering(&b, &a, 1.0);
        assert!(c.is_standard_layout());

        let expected_output = Array::eye(3);
        assert_eq!(c, expected_output);

        use ndarray::ShapeBuilder;
        let a = Array::from_elem((6, 6).f(), 1.0);
        let a = a.slice(ndarray::s![..;2, ..;2]);
        let b = CsMat::<f64>::eye_csc(3);

        let c = super::mul_dense_mat_same_ordering(&b, &a, 1.0);
        assert!(c.t().is_standard_layout());

        let expected_output = Array::eye(3);
        assert_eq!(c, expected_output);
    }

    // In the layout tests below the binop returns a distinctive value
    // (2.0) that is neither the initial content of `out` nor an input
    // value, so the assertions prove every output entry was written.

    #[test]
    fn binop_standard_layouts() {
        use ndarray::ShapeBuilder;
        let csr = CsMat::zero((3, 4));
        let a = Array::from_elem((3, 4), 1.0);
        let mut out = a.clone();
        super::csmat_binop_dense_raw(
            csr.view(),
            a.view(),
            |_: &f32, _: &f32| 2.0,
            out.view_mut(),
        );
        assert!(out.iter().all(|&v| v == 2.0));

        let csc = CsMat::zero((3, 4)).into_csc();
        let a = Array::from_elem((3, 4).f(), 1.0);
        let mut out = Array::zeros((3, 4).f());
        super::csmat_binop_dense_raw(
            csc.view(),
            a.view(),
            |_: &f32, _: &f32| 2.0,
            out.view_mut(),
        );
        assert!(out.iter().all(|&v| v == 2.0));
    }

    #[test]
    fn binop_strided_layouts() {
        // Strided matrices are compatible if they have
        // the same fastest dimension
        use ndarray::{s, ShapeBuilder};
        let csr = CsMat::zero((3, 4));
        let a = Array::from_elem((3, 8), 1.0);
        let a = a.slice(s![.., ..;2]);
        let mut out = Array::zeros((3, 4));
        super::csmat_binop_dense_raw(
            csr.view(),
            a.view(),
            |_: &f32, _: &f32| 2.0,
            out.view_mut(),
        );
        assert!(out.iter().all(|&v| v == 2.0));

        let csc = CsMat::zero((3, 4)).into_csc();
        let a = Array::from_elem((3, 8).f(), 1.0);
        let a = a.slice(s![.., ..;2]);
        let mut out = Array::zeros((3, 4).f());
        super::csmat_binop_dense_raw(
            csc.view(),
            a.view(),
            |_: &f32, _: &f32| 2.0,
            out.view_mut(),
        );
        assert!(out.iter().all(|&v| v == 2.0));
    }
}