sprs/sparse/
prod.rs

1//! Sparse matrix product
2
3use crate::dense_vector::{DenseVector, DenseVectorMut};
4use crate::indexing::SpIndex;
5use crate::sparse::compressed::SpMatView;
6use crate::sparse::prelude::*;
7use crate::Ix2;
8use ndarray::{ArrayView, ArrayViewMut, Axis};
9use num_traits::Num;
10
11/// Compute the dot product of two sparse vectors, using binary search to find matching indices.
12///
13/// Runs in O(MlogN) time, where M and N are the number of non-zero entries in each vector.
14pub fn csvec_dot_by_binary_search<N, I, A, B>(
15    vec1: CsVecViewI<A, I>,
16    vec2: CsVecViewI<B, I>,
17) -> N
18where
19    I: SpIndex,
20    N: crate::MulAcc<A, B> + num_traits::Zero,
21{
22    // Check vec1.nnz<vec2.nnz
23    // Reverse the dot product vec1 and vec2, but preserve possibly non-commutative MulAcc
24    // through a lamba.
25    if vec1.nnz() > vec2.nnz() {
26        csvec_dot_by_binary_search_impl(vec2, vec1, |acc: &mut N, a, b| {
27            acc.mul_acc(b, a);
28        })
29    } else {
30        csvec_dot_by_binary_search_impl(vec1, vec2, |acc: &mut N, a, b| {
31            acc.mul_acc(a, b);
32        })
33    }
34}
35
36/// Inner routine of `csvec_dot_by_binary_search`, removes need for commutative `MulAcc`
37pub(crate) fn csvec_dot_by_binary_search_impl<N, I, A, B, F>(
38    vec1: CsVecViewI<A, I>,
39    vec2: CsVecViewI<B, I>,
40    mul_acc: F,
41) -> N
42where
43    F: Fn(&mut N, &A, &B),
44    I: SpIndex,
45    N: num_traits::Zero,
46{
47    assert!(vec1.nnz() <= vec2.nnz());
48    // vec1.nnz is smaller
49    let (mut idx1, mut val1, mut idx2, mut val2) =
50        (vec1.indices(), vec1.data(), vec2.indices(), vec2.data());
51
52    let mut sum = N::zero();
53    while !idx1.is_empty() && !idx2.is_empty() {
54        debug_assert_eq!(idx1.len(), val1.len());
55        debug_assert_eq!(idx2.len(), val2.len());
56
57        let (found, i) = match idx2.binary_search(&idx1[0]) {
58            Ok(i) => (true, i),
59            Err(i) => (false, i),
60        };
61        if found {
62            mul_acc(&mut sum, &val1[0], &val2[i]);
63        }
64        idx1 = &idx1[1..];
65        val1 = &val1[1..];
66        idx2 = &idx2[i..];
67        val2 = &val2[i..];
68    }
69    sum
70}
71
72/// Multiply a sparse CSC matrix with a dense vector and accumulate the result
73/// into another dense vector
74pub fn mul_acc_mat_vec_csc<N, I, Iptr, V, VRes, A, B>(
75    mat: CsMatViewI<A, I, Iptr>,
76    in_vec: V,
77    mut res_vec: VRes,
78) where
79    N: crate::MulAcc<A, B>,
80    I: SpIndex,
81    Iptr: SpIndex,
82    V: DenseVector<Scalar = B>,
83    VRes: DenseVectorMut<Scalar = N>,
84{
85    let mat = mat.view();
86    assert!(
87        mat.cols() == in_vec.dim() && mat.rows() == res_vec.dim(),
88        "Dimension mismatch"
89    );
90    assert!(mat.is_csc(), "Storage mismatch");
91
92    for (col_ind, vec) in mat.outer_iterator().enumerate() {
93        let vec_elem = in_vec.index(col_ind);
94        for (row_ind, mtx_elem) in vec.iter() {
95            // TODO: unsafe access to value? needs bench
96            res_vec.index_mut(row_ind).mul_acc(mtx_elem, vec_elem);
97        }
98    }
99}
100
101/// Multiply a sparse CSR matrix with a dense vector and accumulate the result
102/// into another dense vector
103pub fn mul_acc_mat_vec_csr<N, A, B, I, Iptr, V, VRes>(
104    mat: CsMatViewI<A, I, Iptr>,
105    in_vec: V,
106    mut res_vec: VRes,
107) where
108    N: crate::MulAcc<A, B>,
109    I: SpIndex,
110    Iptr: SpIndex,
111    V: DenseVector<Scalar = B>,
112    VRes: DenseVectorMut<Scalar = N>,
113{
114    assert!(
115        mat.cols() == in_vec.dim() && mat.rows() == res_vec.dim(),
116        "Dimension mismatch"
117    );
118    assert!(mat.is_csr(), "Storage mismatch");
119
120    for (row_ind, vec) in mat.outer_iterator().enumerate() {
121        let tv = res_vec.index_mut(row_ind);
122        for (col_ind, mtx_elem) in vec.iter() {
123            // TODO: unsafe access to value? needs bench
124            tv.mul_acc(mtx_elem, in_vec.index(col_ind));
125        }
126    }
127}
128
129/// Allocate the appropriate workspace for a CSR-CSR product
130pub fn workspace_csr<N, A, B, I, Iptr, Mat1, Mat2>(
131    _: &Mat1,
132    rhs: &Mat2,
133) -> Vec<N>
134where
135    N: Clone + Num,
136    I: SpIndex,
137    Iptr: SpIndex,
138    Mat1: SpMatView<A, I, Iptr>,
139    Mat2: SpMatView<B, I, Iptr>,
140{
141    let len = rhs.view().cols();
142    vec![N::zero(); len]
143}
144
145/// Allocate the appropriate workspace for a CSC-CSC product
146pub fn workspace_csc<N, A, B, I, Iptr, Mat1, Mat2>(
147    lhs: &Mat1,
148    _: &Mat2,
149) -> Vec<N>
150where
151    N: Clone + Num,
152    I: SpIndex,
153    Iptr: SpIndex,
154    Mat1: SpMatView<A, I, Iptr>,
155    Mat2: SpMatView<B, I, Iptr>,
156{
157    let len = lhs.view().rows();
158    vec![N::zero(); len]
159}
160
161/// CSR-vector multiplication
162pub fn csr_mul_csvec<N, A, B, I, Iptr>(
163    lhs: CsMatViewI<A, I, Iptr>,
164    rhs: CsVecViewI<B, I>,
165) -> CsVecI<N, I>
166where
167    N: crate::MulAcc<A, B> + num_traits::Zero + PartialEq + Clone,
168    I: SpIndex,
169    Iptr: SpIndex,
170{
171    if rhs.dim == 0 {
172        // create an empty sparse vector of correct dimension
173        return CsVecI::empty(0);
174    }
175    assert_eq!(lhs.cols(), rhs.dim(), "Dimension mismatch");
176    let mut res = CsVecI::empty(lhs.rows());
177    for (row_ind, lvec) in lhs.outer_iterator().enumerate() {
178        let val = lvec.dot_acc(&rhs);
179        if val != N::zero() {
180            res.append(row_ind, val);
181        }
182    }
183    res
184}
185
186/// CSR-dense rowmaj multiplication
187///
188/// Performs better if rhs has a decent number of colums.
189pub fn csr_mulacc_dense_rowmaj<'a, N, A, B, I, Iptr>(
190    lhs: CsMatViewI<A, I, Iptr>,
191    rhs: ArrayView<B, Ix2>,
192    mut out: ArrayViewMut<'a, N, Ix2>,
193) where
194    N: 'a + crate::MulAcc<A, B>,
195    I: 'a + SpIndex,
196    Iptr: 'a + SpIndex,
197{
198    assert_eq!(lhs.cols(), rhs.shape()[0], "Dimension mismatch");
199    assert_eq!(lhs.rows(), out.shape()[0], "Dimension mismatch");
200    assert_eq!(rhs.shape()[1], out.shape()[1], "Dimension mismatch");
201    assert!(lhs.is_csr(), "Storage mismatch");
202
203    let axis0 = Axis(0);
204    for (line, mut oline) in lhs.outer_iterator().zip(out.axis_iter_mut(axis0))
205    {
206        for (col_ind, lval) in line.iter() {
207            let rline = rhs.row(col_ind);
208            // TODO: call an axpy primitive to benefit from vectorisation?
209            for (oval, rval) in oline.iter_mut().zip(rline.iter()) {
210                oval.mul_acc(lval, rval);
211            }
212        }
213    }
214}
215
216/// CSC-dense rowmaj multiplication
217///
218/// Performs better if rhs has a decent number of colums.
219pub fn csc_mulacc_dense_rowmaj<'a, N, A, B, I, Iptr>(
220    lhs: CsMatViewI<A, I, Iptr>,
221    rhs: ArrayView<B, Ix2>,
222    mut out: ArrayViewMut<'a, N, Ix2>,
223) where
224    N: 'a + crate::MulAcc<A, B>,
225    I: 'a + SpIndex,
226    Iptr: 'a + SpIndex,
227{
228    assert_eq!(lhs.cols(), rhs.shape()[0], "Dimension mismatch");
229    assert_eq!(lhs.rows(), out.shape()[0], "Dimension mismatch");
230    assert_eq!(rhs.shape()[1], out.shape()[1], "Dimension mismatch");
231    assert!(lhs.is_csc(), "Storage mismatch");
232
233    for (lcol, rline) in lhs.outer_iterator().zip(rhs.outer_iter()) {
234        for (orow, lval) in lcol.iter() {
235            let mut oline = out.row_mut(orow);
236            for (oval, rval) in oline.iter_mut().zip(rline.iter()) {
237                oval.mul_acc(lval, rval);
238            }
239        }
240    }
241}
242
243/// CSC-dense colmaj multiplication
244///
245/// Performs better if rhs has few columns.
246pub fn csc_mulacc_dense_colmaj<'a, N, A, B, I, Iptr>(
247    lhs: CsMatViewI<A, I, Iptr>,
248    rhs: ArrayView<B, Ix2>,
249    mut out: ArrayViewMut<'a, N, Ix2>,
250) where
251    N: 'a + crate::MulAcc<A, B>,
252    I: 'a + SpIndex,
253    Iptr: 'a + SpIndex,
254{
255    assert_eq!(lhs.cols(), rhs.shape()[0], "Dimension mismatch");
256    assert_eq!(lhs.rows(), out.shape()[0], "Dimension mismatch");
257    assert_eq!(rhs.shape()[1], out.shape()[1], "Dimension mismatch");
258    assert!(lhs.is_csc(), "Storage mismatch");
259
260    let axis1 = Axis(1);
261    for (mut ocol, rcol) in out.axis_iter_mut(axis1).zip(rhs.axis_iter(axis1)) {
262        for (rrow, lcol) in lhs.outer_iterator().enumerate() {
263            let rval = &rcol[[rrow]];
264            for (orow, lval) in lcol.iter() {
265                ocol[[orow]].mul_acc(lval, rval);
266            }
267        }
268    }
269}
270
271/// CSR-dense colmaj multiplication
272///
273/// Performs better if rhs has few columns.
274pub fn csr_mulacc_dense_colmaj<'a, N, A, B, I, Iptr>(
275    lhs: CsMatViewI<A, I, Iptr>,
276    rhs: ArrayView<B, Ix2>,
277    mut out: ArrayViewMut<'a, N, Ix2>,
278) where
279    N: 'a + crate::MulAcc<A, B>,
280    I: 'a + SpIndex,
281    Iptr: 'a + SpIndex,
282{
283    assert_eq!(lhs.cols(), rhs.shape()[0], "Dimension mismatch");
284    assert_eq!(lhs.rows(), out.shape()[0], "Dimension mismatch");
285    assert_eq!(rhs.shape()[1], out.shape()[1], "Dimension mismatch");
286    assert!(lhs.is_csr(), "Storage mismatch");
287
288    let axis1 = Axis(1);
289    for (mut ocol, rcol) in out.axis_iter_mut(axis1).zip(rhs.axis_iter(axis1)) {
290        for (orow, lrow) in lhs.outer_iterator().enumerate() {
291            let oval = &mut ocol[[orow]];
292            for (rrow, lval) in lrow.iter() {
293                let rval = &rcol[[rrow]];
294                oval.mul_acc(lval, rval);
295            }
296        }
297    }
298}
299
#[cfg(test)]
mod test {
    use super::*;
    use crate::sparse::{CsMat, CsMatView, CsVec};
    use crate::test_data::{
        mat1, mat1_csc, mat1_csc_matprod_mat4, mat1_matprod_mat2,
        mat1_self_matprod, mat2, mat4, mat5, mat_dense1, mat_dense1_colmaj,
        mat_dense2,
    };
    use ndarray::linalg::Dot;
    use ndarray::{arr1, arr2, s, Array, Array2, Dimension, ShapeBuilder};

    /// Binary-search dot product on vectors with equal and with different
    /// numbers of stored entries.
    #[test]
    fn test_csvec_dot_by_binary_search() {
        let vec1 = CsVecI::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
        let vec2 = CsVecI::new(8, vec![1, 3, 5, 7], vec![2.; 4]);
        let vec3 = CsVecI::new(8, vec![1, 2, 5, 6], vec![3.; 4]);

        assert_eq!(0., csvec_dot_by_binary_search(vec1.view(), vec2.view()));
        assert_eq!(4., csvec_dot_by_binary_search(vec1.view(), vec1.view()));
        assert_eq!(16., csvec_dot_by_binary_search(vec2.view(), vec2.view()));
        assert_eq!(6., csvec_dot_by_binary_search(vec1.view(), vec3.view()));
        assert_eq!(12., csvec_dot_by_binary_search(vec2.view(), vec3.view()));

        // Operands with different numbers of stored entries, in both argument
        // orders, so that the operand swap inside the function is exercised.
        let vec4 = CsVecI::new(8, vec![2, 6], vec![5.; 2]);
        assert_eq!(10., csvec_dot_by_binary_search(vec1.view(), vec4.view()));
        assert_eq!(10., csvec_dot_by_binary_search(vec4.view(), vec1.view()));
        assert_eq!(0., csvec_dot_by_binary_search(vec2.view(), vec4.view()));
        assert_eq!(30., csvec_dot_by_binary_search(vec4.view(), vec3.view()));
    }

    /// CSC matrix times dense vector, with `Vec` operands.
    #[test]
    fn mul_csc_vec() {
        let indptr: &[usize] = &[0, 2, 4, 5, 6, 7];
        let indices: &[usize] = &[2, 3, 3, 4, 2, 1, 3];
        let data: &[f64] = &[
            0.35310881, 0.42380633, 0.28035896, 0.58082095, 0.53350123,
            0.88132896, 0.72527863,
        ];

        let mat = CsMatView::new_csc((5, 5), indptr, indices, data);
        let vector = vec![0.1, 0.2, -0.1, 0.3, 0.9];
        let mut res_vec = vec![0., 0., 0., 0., 0.];
        mul_acc_mat_vec_csc(mat, &vector, &mut res_vec);

        let expected_output =
            vec![0., 0.26439869, -0.01803924, 0.75120319, 0.11616419];

        let epsilon = 1e-7; // TODO: get better values and increase precision

        assert!(res_vec
            .iter()
            .zip(expected_output.iter())
            .all(|(x, y)| (*x - *y).abs() < epsilon));
    }

    /// CSC matrix times dense vector, with ndarray operands.
    #[test]
    fn mul_csc_vec_ndarray() {
        let indptr: &[usize] = &[0, 2, 4, 5, 6, 7];
        let indices: &[usize] = &[2, 3, 3, 4, 2, 1, 3];
        let data: &[f64] = &[
            0.35310881, 0.42380633, 0.28035896, 0.58082095, 0.53350123,
            0.88132896, 0.72527863,
        ];

        let mat = CsMatView::new_csc((5, 5), indptr, indices, data);
        let vector = arr1(&[0.1f64, 0.2, -0.1, 0.3, 0.9]);
        let mut res_vec = Array::<f64, _>::zeros(5);
        mul_acc_mat_vec_csc(mat, vector, res_vec.view_mut());

        let expected_output =
            vec![0., 0.26439869, -0.01803924, 0.75120319, 0.11616419];

        let epsilon = 1e-7; // TODO: get better values and increase precision

        assert!(res_vec
            .iter()
            .zip(expected_output.iter())
            .all(|(x, y)| (*x - *y).abs() < epsilon));
    }

    /// CSR matrix times dense vector, with slice and `Vec` operands.
    #[test]
    fn mul_csr_vec() {
        let indptr: &[usize] = &[0, 3, 3, 5, 6, 7];
        let indices: &[usize] = &[1, 2, 3, 2, 3, 4, 4];
        let data: &[f64] = &[
            0.75672424, 0.1649078, 0.30140296, 0.10358244, 0.6283315,
            0.39244208, 0.57202407,
        ];

        let mat = CsMatView::new((5, 5), indptr, indices, data);
        let slice: &[f64] = &[0.1, 0.2, -0.1, 0.3, 0.9];
        let mut res_vec = vec![0., 0., 0., 0., 0.];
        mul_acc_mat_vec_csr(mat, slice, &mut res_vec);

        let expected_output =
            vec![0.22527496, 0., 0.17814121, 0.35319787, 0.51482166];

        let epsilon = 1e-7; // TODO: get better values and increase precision

        assert!(res_vec
            .iter()
            .zip(expected_output.iter())
            .all(|(x, y)| (*x - *y).abs() < epsilon));
    }

    /// CSR matrix times dense vector, with ndarray operands.
    #[test]
    fn mul_csr_vec_ndarray() {
        let indptr: &[usize] = &[0, 3, 3, 5, 6, 7];
        let indices: &[usize] = &[1, 2, 3, 2, 3, 4, 4];
        let data: &[f64] = &[
            0.75672424, 0.1649078, 0.30140296, 0.10358244, 0.6283315,
            0.39244208, 0.57202407,
        ];

        let mat = CsMatView::new((5, 5), indptr, indices, data);
        let vec = arr1(&[0.1f64, 0.2, -0.1, 0.3, 0.9]);
        let mut res_vec = Array::<f64, _>::zeros(5);
        mul_acc_mat_vec_csr(mat, vec.view(), res_vec.view_mut());

        let expected_output =
            [0.22527496, 0., 0.17814121, 0.35319787, 0.51482166];

        let epsilon = 1e-7; // TODO: get better values and increase precision

        assert!(res_vec
            .iter()
            .zip(expected_output.iter())
            .all(|(x, y)| (*x - *y).abs() < epsilon));
    }

    #[test]
    fn mul_csr_csr() {
        let a = mat1();
        let res = &a * &a;
        let expected_output = mat1_self_matprod();
        assert_eq!(expected_output, res);

        let b = mat2();
        let res = &a * &b;
        let expected_output = mat1_matprod_mat2();
        assert_eq!(expected_output, res);
    }

    #[test]
    fn mul_csc_csc() {
        let a = mat1_csc();
        let b = mat4();
        let res = &a * &b;
        let expected_output = mat1_csc_matprod_mat4();
        assert_eq!(expected_output, res);
    }

    #[test]
    fn mul_csc_csr() {
        let a = mat1();
        let a_ = mat1_csc();
        let expected_output = mat1_self_matprod();

        let res = &a * &a_;
        assert_eq!(expected_output, res);

        let res = (&a_ * &a).to_other_storage();
        assert_eq!(expected_output, res);
    }

    #[test]
    fn mul_csr_csvec() {
        let a = mat1();
        let v = CsVec::new(5, vec![0, 2, 4], vec![1.; 3]);
        let res = &a * &v;
        let expected_output = CsVec::new(5, vec![0, 1, 2], vec![3., 5., 5.]);
        assert_eq!(expected_output, res);
    }

    /// A zero-dimension sparse vector operand yields a zero-dimension result.
    #[test]
    fn mul_csr_zero_csvec() {
        let zero = CsVec::new(0, vec![], vec![]);
        assert_eq!(&mat1() * &zero, zero);
    }

    #[test]
    fn mul_csvec_csr() {
        let a = mat1();
        let v = CsVec::new(5, vec![0, 2, 4], vec![1.; 3]);
        let res = &v * &a;
        let expected_output = CsVec::new(5, vec![2, 3], vec![8., 11.]);
        assert_eq!(expected_output, res);
    }

    #[test]
    fn mul_csc_csvec() {
        let a = mat1_csc();
        let v = CsVec::new(5, vec![0, 2, 4], vec![1.; 3]);
        let res = &a * &v;
        let expected_output = CsVec::new(5, vec![0, 1, 2], vec![3., 5., 5.]);
        assert_eq!(expected_output, res);
    }

    #[test]
    fn mul_csvec_csc() {
        let a = mat1_csc();
        let v = CsVec::new(5, vec![0, 2, 4], vec![1.; 3]);
        let res = &v * &a;
        let expected_output = CsVec::new(5, vec![2, 3], vec![8., 11.]);
        assert_eq!(expected_output, res);
    }

    #[test]
    fn mul_csr_dense_rowmaj() {
        let a: Array2<f64> = Array::eye(3);
        let e: CsMat<f64> = CsMat::eye(3);
        let mut res = Array::<f64, _>::zeros((3, 3));
        super::csr_mulacc_dense_rowmaj(e.view(), a.view(), res.view_mut());
        assert_eq!(res, a);

        let a = mat1();
        let b = mat_dense1();
        let mut res = Array::<f64, _>::zeros((5, 5));
        super::csr_mulacc_dense_rowmaj(a.view(), b.view(), res.view_mut());
        let expected_output = arr2(&[
            [24., 31., 24., 17., 10.],
            [11., 18., 11., 9., 2.],
            [20., 25., 20., 15., 10.],
            [40., 48., 40., 32., 24.],
            [21., 28., 21., 14., 7.],
        ]);
        assert_eq!(res, expected_output);

        let c = &a * &b;
        assert_eq!(c, expected_output);

        let a = mat5();
        let b = mat_dense2();
        let mut res = Array::<f64, _>::zeros((5, 7));
        super::csr_mulacc_dense_rowmaj(a.view(), b.view(), res.view_mut());
        let expected_output = arr2(&[
            [130.04, 150.1, 87.19, 90.89, 99.48, 80.43, 99.3],
            [217.72, 161.61, 79.47, 121.5, 124.23, 146.91, 157.79],
            [55.6, 59.95, 86.7, 0.9, 37.4, 71.66, 51.94],
            [118.18, 123.16, 128.04, 92.02, 106.84, 175.1, 87.36],
            [43.4, 54.1, 12.65, 44.35, 39.9, 23.4, 76.6],
        ]);
        let eps = 1e-8;
        assert!(res
            .iter()
            .zip(expected_output.iter())
            .all(|(&x, &y)| (x - y).abs() <= eps));
    }

    #[test]
    fn mul_csc_dense_rowmaj() {
        let a = mat1_csc();
        let b = mat_dense1();
        let mut res = Array::<f64, _>::zeros((5, 5));
        super::csc_mulacc_dense_rowmaj(a.view(), b.view(), res.view_mut());
        let expected_output = arr2(&[
            [24., 31., 24., 17., 10.],
            [11., 18., 11., 9., 2.],
            [20., 25., 20., 15., 10.],
            [40., 48., 40., 32., 24.],
            [21., 28., 21., 14., 7.],
        ]);
        assert_eq!(res, expected_output);

        let c = &a * &b;
        assert_eq!(c, expected_output);
    }

    #[test]
    fn mul_csc_dense_colmaj() {
        let a = mat1_csc();
        let b = mat_dense1_colmaj();
        let mut res = Array::<f64, _>::zeros((5, 5).f());
        super::csc_mulacc_dense_colmaj(a.view(), b.view(), res.view_mut());
        let v = vec![
            24., 11., 20., 40., 21., 31., 18., 25., 48., 28., 24., 11., 20.,
            40., 21., 17., 9., 15., 32., 14., 10., 2., 10., 24., 7.,
        ];
        let expected_output = Array::from_shape_vec((5, 5).f(), v).unwrap();
        assert_eq!(res, expected_output);

        let c = &a * &b;
        assert_eq!(c, expected_output);
    }

    #[test]
    fn mul_csr_dense_colmaj() {
        let a = mat1();
        let b = mat_dense1_colmaj();
        let mut res = Array::<f64, _>::zeros((5, 5).f());
        super::csr_mulacc_dense_colmaj(a.view(), b.view(), res.view_mut());
        let v = vec![
            24., 11., 20., 40., 21., 31., 18., 25., 48., 28., 24., 11., 20.,
            40., 21., 17., 9., 15., 32., 14., 10., 2., 10., 24., 7.,
        ];
        let expected_output = Array::from_shape_vec((5, 5).f(), v).unwrap();
        assert_eq!(res, expected_output);

        let c = &a * &b;
        assert_eq!(c, expected_output);
    }

    /// Panic unless `a` and `b` agree element-wise within
    /// `atol + rtol * |b|`.
    ///
    /// A NaN in either array also makes the check fail.
    // stolen from ndarray - not currently exported.
    fn assert_close<D>(a: ArrayView<f64, D>, b: ArrayView<f64, D>)
    where
        D: Dimension,
    {
        let diff = (&a - &b).mapv_into(f64::abs);

        let rtol = 1e-7;
        let atol = 1e-12;
        let crtol = b.mapv(|x| x.abs() * rtol);
        let tol = crtol + atol;
        let tol_m_diff = &diff - &tol;
        // `f64::max` returns the non-NaN operand, so a NaN seed is replaced
        // by the first real value, but for the same reason NaN differences
        // would go unnoticed by the fold: detect them separately.
        let has_nan = tol_m_diff.iter().any(|x| x.is_nan());
        let maxdiff = tol_m_diff.fold(f64::NAN, |x, y| f64::max(x, *y));
        println!("diff offset from tolerance level= {:.2e}", maxdiff);
        if has_nan || maxdiff > 0. {
            println!("{:.4?}", a);
            println!("{:.4?}", b);
            panic!("results differ");
        }
    }

    /// `sparse.dot(dense)` must agree with the dense-dense product for all
    /// storage and memory layout combinations.
    #[test]
    #[cfg_attr(
        miri,
        ignore = "https://github.com/rust-ndarray/ndarray/issues/1178"
    )]
    fn test_sparse_dot_dense() {
        let sparse = [
            mat1(),
            mat1_csc(),
            mat2(),
            mat2().transpose_into(),
            mat4(),
            mat5(),
        ];
        let dense = [
            mat_dense1(),
            mat_dense1_colmaj(),
            mat_dense1().reversed_axes(),
            mat_dense2(),
            mat_dense2().reversed_axes(),
        ];

        // test sparse.dot(dense)
        for s in sparse.iter() {
            for d in dense.iter() {
                if d.shape()[0] < s.cols() {
                    continue;
                }

                let d = d.slice(s![0..s.cols(), ..]);

                let truth = s.to_dense().dot(&d);
                let test = s.dot(&d);
                assert_close(test.view(), truth.view());
            }
        }
    }

    /// `dense.dot(sparse)` must agree with the dense-dense product for all
    /// storage and memory layout combinations.
    #[test]
    #[cfg_attr(
        miri,
        ignore = "https://github.com/rust-ndarray/ndarray/issues/1178"
    )]
    fn test_dense_dot_sparse() {
        let sparse = [
            mat1(),
            mat1_csc(),
            mat2(),
            mat2().transpose_into(),
            mat4(),
            mat5(),
        ];
        let dense = [
            mat_dense1(),
            mat_dense1_colmaj(),
            mat_dense1().reversed_axes(),
            mat_dense2(),
            mat_dense2().reversed_axes(),
        ];

        // test sparse.ldot(dense)
        for s in sparse.iter() {
            for d in dense.iter() {
                if d.shape()[1] < s.rows() {
                    continue;
                }

                let d = d.slice(s![.., 0..s.rows()]);

                let truth = d.dot(&s.to_dense());
                let test = d.dot(s);
                assert_close(test.view(), truth.view());
            }
        }
    }
}