1use crate::dense_vector::{DenseVector, DenseVectorMut};
4use crate::indexing::SpIndex;
5use crate::sparse::compressed::SpMatView;
6use crate::sparse::prelude::*;
7use crate::Ix2;
8use ndarray::{ArrayView, ArrayViewMut, Axis};
9use num_traits::Num;
10
11pub fn csvec_dot_by_binary_search<N, I, A, B>(
15 vec1: CsVecViewI<A, I>,
16 vec2: CsVecViewI<B, I>,
17) -> N
18where
19 I: SpIndex,
20 N: crate::MulAcc<A, B> + num_traits::Zero,
21{
22 if vec1.nnz() > vec2.nnz() {
26 csvec_dot_by_binary_search_impl(vec2, vec1, |acc: &mut N, a, b| {
27 acc.mul_acc(b, a);
28 })
29 } else {
30 csvec_dot_by_binary_search_impl(vec1, vec2, |acc: &mut N, a, b| {
31 acc.mul_acc(a, b);
32 })
33 }
34}
35
36pub(crate) fn csvec_dot_by_binary_search_impl<N, I, A, B, F>(
38 vec1: CsVecViewI<A, I>,
39 vec2: CsVecViewI<B, I>,
40 mul_acc: F,
41) -> N
42where
43 F: Fn(&mut N, &A, &B),
44 I: SpIndex,
45 N: num_traits::Zero,
46{
47 assert!(vec1.nnz() <= vec2.nnz());
48 let (mut idx1, mut val1, mut idx2, mut val2) =
50 (vec1.indices(), vec1.data(), vec2.indices(), vec2.data());
51
52 let mut sum = N::zero();
53 while !idx1.is_empty() && !idx2.is_empty() {
54 debug_assert_eq!(idx1.len(), val1.len());
55 debug_assert_eq!(idx2.len(), val2.len());
56
57 let (found, i) = match idx2.binary_search(&idx1[0]) {
58 Ok(i) => (true, i),
59 Err(i) => (false, i),
60 };
61 if found {
62 mul_acc(&mut sum, &val1[0], &val2[i]);
63 }
64 idx1 = &idx1[1..];
65 val1 = &val1[1..];
66 idx2 = &idx2[i..];
67 val2 = &val2[i..];
68 }
69 sum
70}
71
72pub fn mul_acc_mat_vec_csc<N, I, Iptr, V, VRes, A, B>(
75 mat: CsMatViewI<A, I, Iptr>,
76 in_vec: V,
77 mut res_vec: VRes,
78) where
79 N: crate::MulAcc<A, B>,
80 I: SpIndex,
81 Iptr: SpIndex,
82 V: DenseVector<Scalar = B>,
83 VRes: DenseVectorMut<Scalar = N>,
84{
85 let mat = mat.view();
86 assert!(
87 mat.cols() == in_vec.dim() && mat.rows() == res_vec.dim(),
88 "Dimension mismatch"
89 );
90 assert!(mat.is_csc(), "Storage mismatch");
91
92 for (col_ind, vec) in mat.outer_iterator().enumerate() {
93 let vec_elem = in_vec.index(col_ind);
94 for (row_ind, mtx_elem) in vec.iter() {
95 res_vec.index_mut(row_ind).mul_acc(mtx_elem, vec_elem);
97 }
98 }
99}
100
101pub fn mul_acc_mat_vec_csr<N, A, B, I, Iptr, V, VRes>(
104 mat: CsMatViewI<A, I, Iptr>,
105 in_vec: V,
106 mut res_vec: VRes,
107) where
108 N: crate::MulAcc<A, B>,
109 I: SpIndex,
110 Iptr: SpIndex,
111 V: DenseVector<Scalar = B>,
112 VRes: DenseVectorMut<Scalar = N>,
113{
114 assert!(
115 mat.cols() == in_vec.dim() && mat.rows() == res_vec.dim(),
116 "Dimension mismatch"
117 );
118 assert!(mat.is_csr(), "Storage mismatch");
119
120 for (row_ind, vec) in mat.outer_iterator().enumerate() {
121 let tv = res_vec.index_mut(row_ind);
122 for (col_ind, mtx_elem) in vec.iter() {
123 tv.mul_acc(mtx_elem, in_vec.index(col_ind));
125 }
126 }
127}
128
129pub fn workspace_csr<N, A, B, I, Iptr, Mat1, Mat2>(
131 _: &Mat1,
132 rhs: &Mat2,
133) -> Vec<N>
134where
135 N: Clone + Num,
136 I: SpIndex,
137 Iptr: SpIndex,
138 Mat1: SpMatView<A, I, Iptr>,
139 Mat2: SpMatView<B, I, Iptr>,
140{
141 let len = rhs.view().cols();
142 vec![N::zero(); len]
143}
144
145pub fn workspace_csc<N, A, B, I, Iptr, Mat1, Mat2>(
147 lhs: &Mat1,
148 _: &Mat2,
149) -> Vec<N>
150where
151 N: Clone + Num,
152 I: SpIndex,
153 Iptr: SpIndex,
154 Mat1: SpMatView<A, I, Iptr>,
155 Mat2: SpMatView<B, I, Iptr>,
156{
157 let len = lhs.view().rows();
158 vec![N::zero(); len]
159}
160
161pub fn csr_mul_csvec<N, A, B, I, Iptr>(
163 lhs: CsMatViewI<A, I, Iptr>,
164 rhs: CsVecViewI<B, I>,
165) -> CsVecI<N, I>
166where
167 N: crate::MulAcc<A, B> + num_traits::Zero + PartialEq + Clone,
168 I: SpIndex,
169 Iptr: SpIndex,
170{
171 if rhs.dim == 0 {
172 return CsVecI::empty(0);
174 }
175 assert_eq!(lhs.cols(), rhs.dim(), "Dimension mismatch");
176 let mut res = CsVecI::empty(lhs.rows());
177 for (row_ind, lvec) in lhs.outer_iterator().enumerate() {
178 let val = lvec.dot_acc(&rhs);
179 if val != N::zero() {
180 res.append(row_ind, val);
181 }
182 }
183 res
184}
185
186pub fn csr_mulacc_dense_rowmaj<'a, N, A, B, I, Iptr>(
190 lhs: CsMatViewI<A, I, Iptr>,
191 rhs: ArrayView<B, Ix2>,
192 mut out: ArrayViewMut<'a, N, Ix2>,
193) where
194 N: 'a + crate::MulAcc<A, B>,
195 I: 'a + SpIndex,
196 Iptr: 'a + SpIndex,
197{
198 assert_eq!(lhs.cols(), rhs.shape()[0], "Dimension mismatch");
199 assert_eq!(lhs.rows(), out.shape()[0], "Dimension mismatch");
200 assert_eq!(rhs.shape()[1], out.shape()[1], "Dimension mismatch");
201 assert!(lhs.is_csr(), "Storage mismatch");
202
203 let axis0 = Axis(0);
204 for (line, mut oline) in lhs.outer_iterator().zip(out.axis_iter_mut(axis0))
205 {
206 for (col_ind, lval) in line.iter() {
207 let rline = rhs.row(col_ind);
208 for (oval, rval) in oline.iter_mut().zip(rline.iter()) {
210 oval.mul_acc(lval, rval);
211 }
212 }
213 }
214}
215
216pub fn csc_mulacc_dense_rowmaj<'a, N, A, B, I, Iptr>(
220 lhs: CsMatViewI<A, I, Iptr>,
221 rhs: ArrayView<B, Ix2>,
222 mut out: ArrayViewMut<'a, N, Ix2>,
223) where
224 N: 'a + crate::MulAcc<A, B>,
225 I: 'a + SpIndex,
226 Iptr: 'a + SpIndex,
227{
228 assert_eq!(lhs.cols(), rhs.shape()[0], "Dimension mismatch");
229 assert_eq!(lhs.rows(), out.shape()[0], "Dimension mismatch");
230 assert_eq!(rhs.shape()[1], out.shape()[1], "Dimension mismatch");
231 assert!(lhs.is_csc(), "Storage mismatch");
232
233 for (lcol, rline) in lhs.outer_iterator().zip(rhs.outer_iter()) {
234 for (orow, lval) in lcol.iter() {
235 let mut oline = out.row_mut(orow);
236 for (oval, rval) in oline.iter_mut().zip(rline.iter()) {
237 oval.mul_acc(lval, rval);
238 }
239 }
240 }
241}
242
243pub fn csc_mulacc_dense_colmaj<'a, N, A, B, I, Iptr>(
247 lhs: CsMatViewI<A, I, Iptr>,
248 rhs: ArrayView<B, Ix2>,
249 mut out: ArrayViewMut<'a, N, Ix2>,
250) where
251 N: 'a + crate::MulAcc<A, B>,
252 I: 'a + SpIndex,
253 Iptr: 'a + SpIndex,
254{
255 assert_eq!(lhs.cols(), rhs.shape()[0], "Dimension mismatch");
256 assert_eq!(lhs.rows(), out.shape()[0], "Dimension mismatch");
257 assert_eq!(rhs.shape()[1], out.shape()[1], "Dimension mismatch");
258 assert!(lhs.is_csc(), "Storage mismatch");
259
260 let axis1 = Axis(1);
261 for (mut ocol, rcol) in out.axis_iter_mut(axis1).zip(rhs.axis_iter(axis1)) {
262 for (rrow, lcol) in lhs.outer_iterator().enumerate() {
263 let rval = &rcol[[rrow]];
264 for (orow, lval) in lcol.iter() {
265 ocol[[orow]].mul_acc(lval, rval);
266 }
267 }
268 }
269}
270
271pub fn csr_mulacc_dense_colmaj<'a, N, A, B, I, Iptr>(
275 lhs: CsMatViewI<A, I, Iptr>,
276 rhs: ArrayView<B, Ix2>,
277 mut out: ArrayViewMut<'a, N, Ix2>,
278) where
279 N: 'a + crate::MulAcc<A, B>,
280 I: 'a + SpIndex,
281 Iptr: 'a + SpIndex,
282{
283 assert_eq!(lhs.cols(), rhs.shape()[0], "Dimension mismatch");
284 assert_eq!(lhs.rows(), out.shape()[0], "Dimension mismatch");
285 assert_eq!(rhs.shape()[1], out.shape()[1], "Dimension mismatch");
286 assert!(lhs.is_csr(), "Storage mismatch");
287
288 let axis1 = Axis(1);
289 for (mut ocol, rcol) in out.axis_iter_mut(axis1).zip(rhs.axis_iter(axis1)) {
290 for (orow, lrow) in lhs.outer_iterator().enumerate() {
291 let oval = &mut ocol[[orow]];
292 for (rrow, lval) in lrow.iter() {
293 let rval = &rcol[[rrow]];
294 oval.mul_acc(lval, rval);
295 }
296 }
297 }
298}
299
#[cfg(test)]
mod test {
    use super::*;
    use crate::sparse::{CsMat, CsMatView, CsVec};
    use crate::test_data::{
        mat1, mat1_csc, mat1_csc_matprod_mat4, mat1_matprod_mat2,
        mat1_self_matprod, mat2, mat4, mat5, mat_dense1, mat_dense1_colmaj,
        mat_dense2,
    };
    use ndarray::linalg::Dot;
    use ndarray::{arr1, arr2, s, Array, Array2, Dimension, ShapeBuilder};

    /// Binary-search dot product on vectors with disjoint, identical and
    /// partially overlapping index sets.
    #[test]
    fn test_csvec_dot_by_binary_search() {
        let vec1 = CsVecI::new(8, vec![0, 2, 4, 6], vec![1.; 4]);
        let vec2 = CsVecI::new(8, vec![1, 3, 5, 7], vec![2.; 4]);
        let vec3 = CsVecI::new(8, vec![1, 2, 5, 6], vec![3.; 4]);

        assert_eq!(0., csvec_dot_by_binary_search(vec1.view(), vec2.view()));
        assert_eq!(4., csvec_dot_by_binary_search(vec1.view(), vec1.view()));
        assert_eq!(16., csvec_dot_by_binary_search(vec2.view(), vec2.view()));
        assert_eq!(6., csvec_dot_by_binary_search(vec1.view(), vec3.view()));
        assert_eq!(12., csvec_dot_by_binary_search(vec2.view(), vec3.view()));
    }

    /// CSC matrix times a dense vector held in a `Vec`.
    #[test]
    fn mul_csc_vec() {
        let indptr: &[usize] = &[0, 2, 4, 5, 6, 7];
        let indices: &[usize] = &[2, 3, 3, 4, 2, 1, 3];
        let data: &[f64] = &[
            0.35310881, 0.42380633, 0.28035896, 0.58082095, 0.53350123,
            0.88132896, 0.72527863,
        ];

        let mat = CsMatView::new_csc((5, 5), indptr, indices, data);
        let vector = vec![0.1, 0.2, -0.1, 0.3, 0.9];
        let mut res_vec = vec![0., 0., 0., 0., 0.];
        mul_acc_mat_vec_csc(mat, &vector, &mut res_vec);

        let expected_output =
            vec![0., 0.26439869, -0.01803924, 0.75120319, 0.11616419];

        let epsilon = 1e-7;
        assert!(res_vec
            .iter()
            .zip(expected_output.iter())
            .all(|(x, y)| (*x - *y).abs() < epsilon));
    }

    /// CSC matrix times a dense vector held in an ndarray array.
    #[test]
    fn mul_csc_vec_ndarray() {
        let indptr: &[usize] = &[0, 2, 4, 5, 6, 7];
        let indices: &[usize] = &[2, 3, 3, 4, 2, 1, 3];
        let data: &[f64] = &[
            0.35310881, 0.42380633, 0.28035896, 0.58082095, 0.53350123,
            0.88132896, 0.72527863,
        ];

        let mat = CsMatView::new_csc((5, 5), indptr, indices, data);
        let vector = arr1(&[0.1f64, 0.2, -0.1, 0.3, 0.9]);
        let mut res_vec = Array::<f64, _>::zeros(5);
        mul_acc_mat_vec_csc(mat, vector, res_vec.view_mut());

        let expected_output =
            vec![0., 0.26439869, -0.01803924, 0.75120319, 0.11616419];

        let epsilon = 1e-7;
        assert!(res_vec
            .iter()
            .zip(expected_output.iter())
            .all(|(x, y)| (*x - *y).abs() < epsilon));
    }

    /// CSR matrix times a dense vector held in a slice.
    #[test]
    fn mul_csr_vec() {
        let indptr: &[usize] = &[0, 3, 3, 5, 6, 7];
        let indices: &[usize] = &[1, 2, 3, 2, 3, 4, 4];
        let data: &[f64] = &[
            0.75672424, 0.1649078, 0.30140296, 0.10358244, 0.6283315,
            0.39244208, 0.57202407,
        ];

        let mat = CsMatView::new((5, 5), indptr, indices, data);
        let slice: &[f64] = &[0.1, 0.2, -0.1, 0.3, 0.9];
        let mut res_vec = vec![0., 0., 0., 0., 0.];
        mul_acc_mat_vec_csr(mat, slice, &mut res_vec);

        let expected_output =
            vec![0.22527496, 0., 0.17814121, 0.35319787, 0.51482166];

        let epsilon = 1e-7;
        assert!(res_vec
            .iter()
            .zip(expected_output.iter())
            .all(|(x, y)| (*x - *y).abs() < epsilon));
    }

    /// CSR matrix times a dense vector held in an ndarray view.
    #[test]
    fn mul_csr_vec_ndarray() {
        let indptr: &[usize] = &[0, 3, 3, 5, 6, 7];
        let indices: &[usize] = &[1, 2, 3, 2, 3, 4, 4];
        let data: &[f64] = &[
            0.75672424, 0.1649078, 0.30140296, 0.10358244, 0.6283315,
            0.39244208, 0.57202407,
        ];

        let mat = CsMatView::new((5, 5), indptr, indices, data);
        let vec = arr1(&[0.1f64, 0.2, -0.1, 0.3, 0.9]);
        let mut res_vec = Array::<f64, _>::zeros(5);
        mul_acc_mat_vec_csr(mat, vec.view(), res_vec.view_mut());

        let expected_output =
            [0.22527496, 0., 0.17814121, 0.35319787, 0.51482166];

        let epsilon = 1e-7;
        assert!(res_vec
            .iter()
            .zip(expected_output.iter())
            .all(|(x, y)| (*x - *y).abs() < epsilon));
    }

    /// Sparse-sparse products through the `*` operator, CSR operands.
    #[test]
    fn mul_csr_csr() {
        let a = mat1();
        let res = &a * &a;
        let expected_output = mat1_self_matprod();
        assert_eq!(expected_output, res);

        let b = mat2();
        let res = &a * &b;
        let expected_output = mat1_matprod_mat2();
        assert_eq!(expected_output, res);
    }

    /// Sparse-sparse product through the `*` operator, CSC left operand.
    #[test]
    fn mul_csc_csc() {
        let a = mat1_csc();
        let b = mat4();
        let res = &a * &b;
        let expected_output = mat1_csc_matprod_mat4();
        assert_eq!(expected_output, res);
    }

    /// Sparse-sparse products mixing `mat1()` and `mat1_csc()` in both
    /// orders.
    #[test]
    fn mul_csc_csr() {
        let a = mat1();
        let a_ = mat1_csc();
        let expected_output = mat1_self_matprod();

        let res = &a * &a_;
        assert_eq!(expected_output, res);

        let res = (&a_ * &a).to_other_storage();
        assert_eq!(expected_output, res);
    }

    /// Matrix (from `mat1()`) times sparse vector.
    #[test]
    fn mul_csr_csvec() {
        let a = mat1();
        let v = CsVec::new(5, vec![0, 2, 4], vec![1.; 3]);
        let res = &a * &v;
        let expected_output = CsVec::new(5, vec![0, 1, 2], vec![3., 5., 5.]);
        assert_eq!(expected_output, res);
    }

    /// Multiplying by a zero-dimensional sparse vector yields that same
    /// empty vector.
    #[test]
    fn mul_csr_zero_csvec() {
        let zero = CsVec::new(0, vec![], vec![]);
        assert_eq!(&mat1() * &zero, zero);
    }

    /// Sparse vector times matrix (from `mat1()`).
    #[test]
    fn mul_csvec_csr() {
        let a = mat1();
        let v = CsVec::new(5, vec![0, 2, 4], vec![1.; 3]);
        let res = &v * &a;
        let expected_output = CsVec::new(5, vec![2, 3], vec![8., 11.]);
        assert_eq!(expected_output, res);
    }

    /// Matrix (from `mat1_csc()`) times sparse vector.
    #[test]
    fn mul_csc_csvec() {
        let a = mat1_csc();
        let v = CsVec::new(5, vec![0, 2, 4], vec![1.; 3]);
        let res = &a * &v;
        let expected_output = CsVec::new(5, vec![0, 1, 2], vec![3., 5., 5.]);
        assert_eq!(expected_output, res);
    }

    /// Sparse vector times matrix (from `mat1_csc()`).
    #[test]
    fn mul_csvec_csc() {
        let a = mat1_csc();
        let v = CsVec::new(5, vec![0, 2, 4], vec![1.; 3]);
        let res = &v * &a;
        let expected_output = CsVec::new(5, vec![2, 3], vec![8., 11.]);
        assert_eq!(expected_output, res);
    }

    /// `csr_mulacc_dense_rowmaj` on an identity matrix, on `mat1()` and on
    /// a non-square case, plus the `*` operator on the same operands.
    #[test]
    fn mul_csr_dense_rowmaj() {
        let a: Array2<f64> = Array::eye(3);
        let e: CsMat<f64> = CsMat::eye(3);
        let mut res = Array::<f64, _>::zeros((3, 3));
        super::csr_mulacc_dense_rowmaj(e.view(), a.view(), res.view_mut());
        assert_eq!(res, a);

        let a = mat1();
        let b = mat_dense1();
        let mut res = Array::<f64, _>::zeros((5, 5));
        super::csr_mulacc_dense_rowmaj(a.view(), b.view(), res.view_mut());
        let expected_output = arr2(&[
            [24., 31., 24., 17., 10.],
            [11., 18., 11., 9., 2.],
            [20., 25., 20., 15., 10.],
            [40., 48., 40., 32., 24.],
            [21., 28., 21., 14., 7.],
        ]);
        assert_eq!(res, expected_output);

        let c = &a * &b;
        assert_eq!(c, expected_output);

        let a = mat5();
        let b = mat_dense2();
        let mut res = Array::<f64, _>::zeros((5, 7));
        super::csr_mulacc_dense_rowmaj(a.view(), b.view(), res.view_mut());
        let expected_output = arr2(&[
            [130.04, 150.1, 87.19, 90.89, 99.48, 80.43, 99.3],
            [217.72, 161.61, 79.47, 121.5, 124.23, 146.91, 157.79],
            [55.6, 59.95, 86.7, 0.9, 37.4, 71.66, 51.94],
            [118.18, 123.16, 128.04, 92.02, 106.84, 175.1, 87.36],
            [43.4, 54.1, 12.65, 44.35, 39.9, 23.4, 76.6],
        ]);
        let eps = 1e-8;
        assert!(res
            .iter()
            .zip(expected_output.iter())
            .all(|(&x, &y)| (x - y).abs() <= eps));
    }

    /// `csc_mulacc_dense_rowmaj` and the `*` operator on `mat1_csc()`.
    #[test]
    fn mul_csc_dense_rowmaj() {
        let a = mat1_csc();
        let b = mat_dense1();
        let mut res = Array::<f64, _>::zeros((5, 5));
        super::csc_mulacc_dense_rowmaj(a.view(), b.view(), res.view_mut());
        let expected_output = arr2(&[
            [24., 31., 24., 17., 10.],
            [11., 18., 11., 9., 2.],
            [20., 25., 20., 15., 10.],
            [40., 48., 40., 32., 24.],
            [21., 28., 21., 14., 7.],
        ]);
        assert_eq!(res, expected_output);

        let c = &a * &b;
        assert_eq!(c, expected_output);
    }

    /// `csc_mulacc_dense_colmaj` and the `*` operator with arrays built in
    /// `.f()` (column-major) layout.
    #[test]
    fn mul_csc_dense_colmaj() {
        let a = mat1_csc();
        let b = mat_dense1_colmaj();
        let mut res = Array::<f64, _>::zeros((5, 5).f());
        super::csc_mulacc_dense_colmaj(a.view(), b.view(), res.view_mut());
        let v = vec![
            24., 11., 20., 40., 21., 31., 18., 25., 48., 28., 24., 11., 20.,
            40., 21., 17., 9., 15., 32., 14., 10., 2., 10., 24., 7.,
        ];
        let expected_output = Array::from_shape_vec((5, 5).f(), v).unwrap();
        assert_eq!(res, expected_output);

        let c = &a * &b;
        assert_eq!(c, expected_output);
    }

    /// `csr_mulacc_dense_colmaj` and the `*` operator with arrays built in
    /// `.f()` (column-major) layout.
    #[test]
    fn mul_csr_dense_colmaj() {
        let a = mat1();
        let b = mat_dense1_colmaj();
        let mut res = Array::<f64, _>::zeros((5, 5).f());
        super::csr_mulacc_dense_colmaj(a.view(), b.view(), res.view_mut());
        let v = vec![
            24., 11., 20., 40., 21., 31., 18., 25., 48., 28., 24., 11., 20.,
            40., 21., 17., 9., 15., 32., 14., 10., 2., 10., 24., 7.,
        ];
        let expected_output = Array::from_shape_vec((5, 5).f(), v).unwrap();
        assert_eq!(res, expected_output);

        let c = &a * &b;
        assert_eq!(c, expected_output);
    }

    /// Element-wise closeness check between two arrays of the same shape.
    ///
    /// Panics with "results differ" (after printing both arrays) when some
    /// element satisfies `|a - b| > atol + rtol * |b|`, or when the
    /// comparison yields NaN for some element.
    fn assert_close<D>(a: ArrayView<f64, D>, b: ArrayView<f64, D>)
    where
        D: Dimension,
    {
        let diff = (&a - &b).mapv_into(f64::abs);

        let rtol = 1e-7;
        let atol = 1e-12;
        let crtol = b.mapv(|x| x.abs() * rtol);
        let tol = crtol + atol;
        // Positive entries are the ones exceeding their tolerance.
        let tol_m_diff = &diff - &tol;
        // `f64::max` returns the non-NaN operand, so a NaN entry would
        // vanish from the maximum below and must be detected separately.
        let has_nan = tol_m_diff.iter().any(|x| x.is_nan());
        // The NaN seed is likewise discarded as soon as a number is seen.
        let maxdiff = tol_m_diff.fold(f64::NAN, |x, y| f64::max(x, *y));
        println!("diff offset from tolerance level= {:.2e}", maxdiff);
        if has_nan || maxdiff > 0. {
            println!("{:.4?}", a);
            println!("{:.4?}", b);
            panic!("results differ");
        }
    }

    /// Compare `sparse.dot(dense)` against the dense reference product for
    /// every compatible pairing of the sample matrices.
    #[test]
    #[cfg_attr(
        miri,
        ignore = "https://github.com/rust-ndarray/ndarray/issues/1178"
    )]
    fn test_sparse_dot_dense() {
        let sparse = [
            mat1(),
            mat1_csc(),
            mat2(),
            mat2().transpose_into(),
            mat4(),
            mat5(),
        ];
        let dense = [
            mat_dense1(),
            mat_dense1_colmaj(),
            mat_dense1().reversed_axes(),
            mat_dense2(),
            mat_dense2().reversed_axes(),
        ];

        for s in sparse.iter() {
            for d in dense.iter() {
                // The dense operand is cropped to `s.cols()` rows below, so
                // it needs at least that many.
                if d.shape()[0] < s.cols() {
                    continue;
                }

                let d = d.slice(s![0..s.cols(), ..]);

                let truth = s.to_dense().dot(&d);
                let test = s.dot(&d);
                assert_close(test.view(), truth.view());
            }
        }
    }

    /// Compare `dense.dot(sparse)` against the dense reference product for
    /// every compatible pairing of the sample matrices.
    #[test]
    #[cfg_attr(
        miri,
        ignore = "https://github.com/rust-ndarray/ndarray/issues/1178"
    )]
    fn test_dense_dot_sparse() {
        let sparse = [
            mat1(),
            mat1_csc(),
            mat2(),
            mat2().transpose_into(),
            mat4(),
            mat5(),
        ];
        let dense = [
            mat_dense1(),
            mat_dense1_colmaj(),
            mat_dense1().reversed_axes(),
            mat_dense2(),
            mat_dense2().reversed_axes(),
        ];

        for s in sparse.iter() {
            for d in dense.iter() {
                // The dense operand is cropped to `s.rows()` columns below,
                // so it needs at least that many.
                if d.shape()[1] < s.rows() {
                    continue;
                }

                let d = d.slice(s![.., 0..s.rows()]);

                let truth = d.dot(&s.to_dense());
                let test = d.dot(s);
                assert_close(test.view(), truth.view());
            }
        }
    }
}