1use std::fmt::Debug;
2use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
3
4use crate::error::DiffsolError;
5use crate::scalar::Scale;
6use crate::vector::VectorHost;
7use crate::{Context, IndexType, Scalar, Vector, VectorIndex};
8
9use extract_block::combine;
10use num_traits::{One, Zero};
11use sparsity::{Dense, MatrixSparsity, MatrixSparsityRef};
12
13#[cfg(feature = "cuda")]
14pub mod cuda;
15
16#[cfg(feature = "nalgebra")]
17pub mod dense_nalgebra_serial;
18
19#[cfg(feature = "faer")]
20pub mod dense_faer_serial;
21
22#[cfg(feature = "faer")]
23pub mod sparse_faer;
24
25pub mod default_solver;
26pub mod extract_block;
27pub mod sparsity;
28
29#[macro_use]
30mod utils;
31
32pub trait MatrixCommon: Sized + Debug {
34 type V: Vector<T = Self::T, C = Self::C, Index: VectorIndex<C = Self::C>>;
35 type T: Scalar;
36 type C: Context;
37 type Inner;
38
39 fn nrows(&self) -> IndexType;
41 fn ncols(&self) -> IndexType;
43 fn inner(&self) -> &Self::Inner;
45}
46
47impl<M> MatrixCommon for &M
48where
49 M: MatrixCommon,
50{
51 type T = M::T;
52 type V = M::V;
53 type C = M::C;
54 type Inner = M::Inner;
55
56 fn nrows(&self) -> IndexType {
57 M::nrows(*self)
58 }
59 fn ncols(&self) -> IndexType {
60 M::ncols(*self)
61 }
62 fn inner(&self) -> &Self::Inner {
63 M::inner(*self)
64 }
65}
66
67impl<M> MatrixCommon for &mut M
68where
69 M: MatrixCommon,
70{
71 type T = M::T;
72 type V = M::V;
73 type C = M::C;
74 type Inner = M::Inner;
75
76 fn ncols(&self) -> IndexType {
77 M::ncols(*self)
78 }
79 fn nrows(&self) -> IndexType {
80 M::nrows(*self)
81 }
82 fn inner(&self) -> &Self::Inner {
83 M::inner(*self)
84 }
85}
86
87pub trait MatrixOpsByValue<Rhs = Self, Output = Self>:
91 MatrixCommon + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
92{
93}
94
95impl<M, Rhs, Output> MatrixOpsByValue<Rhs, Output> for M where
96 M: MatrixCommon + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
97{
98}
99
100pub trait MatrixMutOpsByValue<Rhs = Self>: MatrixCommon + AddAssign<Rhs> + SubAssign<Rhs> {}
104
105impl<M, Rhs> MatrixMutOpsByValue<Rhs> for M where M: MatrixCommon + AddAssign<Rhs> + SubAssign<Rhs> {}
106
107pub trait MatrixRef<M: MatrixCommon>: Mul<Scale<M::T>, Output = M> {}
109impl<RefT, M: MatrixCommon> MatrixRef<M> for RefT where RefT: Mul<Scale<M::T>, Output = M> {}
110
111pub trait MatrixViewMut<'a>:
117 for<'b> MatrixMutOpsByValue<&'b Self>
118 + for<'b> MatrixMutOpsByValue<&'b Self::View>
119 + MulAssign<Scale<Self::T>>
120{
121 type Owned;
122 type View;
123 fn into_owned(self) -> Self::Owned;
125 fn gemm_oo(&mut self, alpha: Self::T, a: &Self::Owned, b: &Self::Owned, beta: Self::T);
127 fn gemm_vo(&mut self, alpha: Self::T, a: &Self::View, b: &Self::Owned, beta: Self::T);
129}
130
131pub trait MatrixView<'a>:
137 for<'b> MatrixOpsByValue<&'b Self::Owned, Self::Owned> + Mul<Scale<Self::T>, Output = Self::Owned>
138{
139 type Owned;
140
141 fn into_owned(self) -> Self::Owned;
143
144 fn gemv_v(
146 &self,
147 alpha: Self::T,
148 x: &<Self::V as Vector>::View<'_>,
149 beta: Self::T,
150 y: &mut Self::V,
151 );
152
153 fn gemv_o(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
155}
156
157pub trait Matrix:
170 MatrixCommon + Mul<Scale<Self::T>, Output = Self> + Clone + Send + 'static
171{
172 type Sparsity: MatrixSparsity<Self>;
173 type SparsityRef<'a>: MatrixSparsityRef<'a, Self>
174 where
175 Self: 'a;
176
177 fn sparsity(&self) -> Option<Self::SparsityRef<'_>>;
179
180 fn context(&self) -> &Self::C;
182
183 fn inner_mut(&mut self) -> &mut Self::Inner;
185
186 fn is_sparse() -> bool {
188 Self::zeros(1, 1, Default::default()).sparsity().is_some()
189 }
190
191 fn partition_indices_by_zero_diagonal(
196 &self,
197 ) -> (<Self::V as Vector>::Index, <Self::V as Vector>::Index);
198
199 fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
201
202 fn copy_from(&mut self, other: &Self);
204
205 fn zeros(nrows: IndexType, ncols: IndexType, ctx: Self::C) -> Self;
207
208 fn new_from_sparsity(
210 nrows: IndexType,
211 ncols: IndexType,
212 sparsity: Option<Self::Sparsity>,
213 ctx: Self::C,
214 ) -> Self;
215
216 fn from_diagonal(v: &Self::V) -> Self;
218
219 fn set_column(&mut self, j: IndexType, v: &Self::V);
223
224 fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V);
226
227 fn set_data_with_indices(
233 &mut self,
234 dst_indices: &<Self::V as Vector>::Index,
235 src_indices: &<Self::V as Vector>::Index,
236 data: &Self::V,
237 );
238
239 fn gather(&mut self, other: &Self, indices: &<Self::V as Vector>::Index);
245
246 fn split(
262 &self,
263 algebraic_indices: &<Self::V as Vector>::Index,
264 ) -> [(Self, <Self::V as Vector>::Index); 4] {
265 match self.sparsity() {
266 Some(sp) => sp.split(algebraic_indices).map(|(sp, src_indices)| {
267 let mut m = Self::new_from_sparsity(
268 sp.nrows(),
269 sp.ncols(),
270 Some(sp),
271 self.context().clone(),
272 );
273 m.gather(self, &src_indices);
274 (m, src_indices)
275 }),
276 None => Dense::<Self>::new(self.nrows(), self.ncols())
277 .split(algebraic_indices)
278 .map(|(sp, src_indices)| {
279 let mut m = Self::new_from_sparsity(
280 sp.nrows(),
281 sp.ncols(),
282 None,
283 self.context().clone(),
284 );
285 m.gather(self, &src_indices);
286 (m, src_indices)
287 }),
288 }
289 }
290
291 fn combine(
296 ul: &Self,
297 ur: &Self,
298 ll: &Self,
299 lr: &Self,
300 algebraic_indices: &<Self::V as Vector>::Index,
301 ) -> Self {
302 combine(ul, ur, ll, lr, algebraic_indices)
303 }
304
305 fn scale_add_and_assign(&mut self, x: &Self, beta: Self::T, y: &Self);
310
311 fn triplet_iter(
318 &self,
319 ) -> (
320 impl Iterator<Item = (IndexType, IndexType)> + '_,
321 impl Iterator<Item = Self::T> + '_,
322 );
323
324 fn try_from_triplets(
330 nrows: IndexType,
331 ncols: IndexType,
332 indices: Vec<(IndexType, IndexType)>,
333 values: Vec<Self::T>,
334 ctx: Self::C,
335 ) -> Result<Self, DiffsolError>;
336}
337
338pub trait MatrixHost: Matrix<V: VectorHost> {}
343
344impl<T: Matrix<V: VectorHost>> MatrixHost for T {}
345
346pub trait DenseMatrix:
358 Matrix
359 + for<'b> MatrixOpsByValue<&'b Self, Self>
360 + for<'b> MatrixMutOpsByValue<&'b Self>
361 + for<'a, 'b> MatrixOpsByValue<&'b Self::View<'a>, Self>
362 + for<'a, 'b> MatrixMutOpsByValue<&'b Self::View<'a>>
363{
364 type View<'a>: MatrixView<'a, Owned = Self, T = Self::T, V = Self::V>
366 where
367 Self: 'a;
368
369 type ViewMut<'a>: MatrixViewMut<
371 'a,
372 Owned = Self,
373 T = Self::T,
374 V = Self::V,
375 View = Self::View<'a>,
376 >
377 where
378 Self: 'a;
379
380 fn gemm(&mut self, alpha: Self::T, a: &Self, b: &Self, beta: Self::T);
382
383 fn column_axpy(&mut self, alpha: Self::T, j: IndexType, i: IndexType);
387
388 fn columns(&self, start: IndexType, end: IndexType) -> Self::View<'_>;
390
391 fn column(&self, i: IndexType) -> <Self::V as Vector>::View<'_>;
393
394 fn columns_mut(&mut self, start: IndexType, end: IndexType) -> Self::ViewMut<'_>;
396
397 fn column_mut(&mut self, i: IndexType) -> <Self::V as Vector>::ViewMut<'_>;
399
400 fn set_index(&mut self, i: IndexType, j: IndexType, value: Self::T);
402
403 fn get_index(&self, i: IndexType, j: IndexType) -> Self::T;
405
406 fn mat_mul(&self, b: &Self) -> Self {
408 let nrows = self.nrows();
409 let ncols = b.ncols();
410 let mut ret = Self::zeros(nrows, ncols, self.context().clone());
411 ret.gemm(Self::T::one(), self, b, Self::T::zero());
412 ret
413 }
414
415 fn resize_cols(&mut self, ncols: IndexType);
419
420 fn from_vec(nrows: IndexType, ncols: IndexType, data: Vec<Self::T>, ctx: Self::C) -> Self;
424}
425
426#[cfg(test)]
427pub(crate) mod tests {
428 use super::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut};
429 use crate::scalar::Scale;
430 use crate::{scalar::IndexType, Context, Vector, VectorIndex};
431 use num_traits::{FromPrimitive, One, Zero};
432
433 fn f<M: Matrix>(x: f64) -> M::T {
434 M::T::from_f64(x).unwrap()
435 }
436
437 fn triplet_values<M: Matrix>(m: &M) -> Vec<M::T> {
438 let (_, vals) = m.triplet_iter();
439 vals.collect()
440 }
441
442 fn triplet_indices<M: Matrix>(m: &M) -> Vec<(IndexType, IndexType)> {
443 let (idx, _) = m.triplet_iter();
444 idx.collect()
445 }
446
447 pub fn test_partition_indices_by_zero_diagonal<M: Matrix>() {
448 let indices = vec![(0, 0), (1, 1), (3, 3)];
449 let values = vec![M::T::one(), M::T::from_f64(2.0).unwrap(), M::T::one()];
450 let m = M::try_from_triplets(4, 4, indices, values, Default::default()).unwrap();
451 let (zero_diagonal_indices, non_zero_diagonal_indices) =
452 m.partition_indices_by_zero_diagonal();
453 assert_eq!(zero_diagonal_indices.clone_as_vec(), vec![2]);
454 assert_eq!(non_zero_diagonal_indices.clone_as_vec(), vec![0, 1, 3]);
455
456 let indices = vec![(0, 0), (1, 1), (2, 2), (3, 3)];
457 let values = vec![
458 M::T::one(),
459 M::T::from_f64(2.0).unwrap(),
460 M::T::zero(),
461 M::T::one(),
462 ];
463 let m = M::try_from_triplets(4, 4, indices, values, Default::default()).unwrap();
464 let (zero_diagonal_indices, non_zero_diagonal_indices) =
465 m.partition_indices_by_zero_diagonal();
466 assert_eq!(zero_diagonal_indices.clone_as_vec(), vec![2]);
467 assert_eq!(non_zero_diagonal_indices.clone_as_vec(), vec![0, 1, 3]);
468
469 let indices = vec![(0, 0), (1, 1), (2, 2), (3, 3)];
470 let values = vec![
471 M::T::one(),
472 M::T::from_f64(2.0).unwrap(),
473 M::T::from_f64(3.0).unwrap(),
474 M::T::one(),
475 ];
476 let m = M::try_from_triplets(4, 4, indices, values, Default::default()).unwrap();
477 let (zero_diagonal_indices, non_zero_diagonal_indices) =
478 m.partition_indices_by_zero_diagonal();
479 assert_eq!(
480 zero_diagonal_indices.clone_as_vec(),
481 Vec::<IndexType>::new()
482 );
483 assert_eq!(non_zero_diagonal_indices.clone_as_vec(), vec![0, 1, 2, 3]);
484 }
485
486 pub fn test_zeros<M: Matrix>() {
489 let a = M::zeros(2, 3, Default::default());
490 assert_eq!(a.nrows(), 2);
491 assert_eq!(a.ncols(), 3);
492 let vals = triplet_values(&a);
493 assert!(vals.is_empty() || vals.iter().all(|v| v.is_zero()));
494 }
495
496 pub fn test_from_diagonal<M: Matrix>() {
497 let v = M::V::from_vec(
498 vec![f::<M>(2.0), f::<M>(3.0), f::<M>(5.0)],
499 Default::default(),
500 );
501 let a = M::from_diagonal(&v);
502 assert_eq!(a.nrows(), 3);
503 assert_eq!(a.ncols(), 3);
504 let idx = triplet_indices(&a);
505 let vals = triplet_values(&a);
506 for &(i, j) in &idx {
508 let pos = idx.iter().position(|&x| x == (i, j)).unwrap();
509 if i == j {
510 assert!(
511 vals[pos] != M::T::zero(),
512 "diagonal entry should be non-zero"
513 );
514 } else {
515 assert!(vals[pos].is_zero(), "off-diagonal entry should be zero");
516 }
517 }
518 }
519
520 pub fn test_from_diagonal_dense<M: DenseMatrix>() {
521 let v = M::V::from_vec(
522 vec![f::<M>(2.0), f::<M>(3.0), f::<M>(5.0)],
523 Default::default(),
524 );
525 let a = M::from_diagonal(&v);
526 assert_eq!(a.nrows(), 3);
527 assert_eq!(a.ncols(), 3);
528 assert_eq!(a.get_index(0, 0), f::<M>(2.0));
529 assert_eq!(a.get_index(1, 1), f::<M>(3.0));
530 assert_eq!(a.get_index(2, 2), f::<M>(5.0));
531 assert_eq!(a.get_index(0, 1), f::<M>(0.0));
532 assert_eq!(a.get_index(1, 0), f::<M>(0.0));
533 }
534
535 pub fn test_gemv<M: Matrix>() {
536 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
537 let values = vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)];
538 let a = M::try_from_triplets(2, 2, indices, values, Default::default()).unwrap();
539 let x = M::V::from_vec(vec![f::<M>(1.0), f::<M>(2.0)], Default::default());
540 let mut y = M::V::zeros(2, Default::default());
541 a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
542 assert_eq!(y.clone_as_vec(), vec![f::<M>(5.0), f::<M>(11.0)]);
543 }
544
545 pub fn test_set_column<M: Matrix>() {
546 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
547 let values = vec![f::<M>(0.0), f::<M>(0.0), f::<M>(0.0), f::<M>(0.0)];
548 let mut a = M::try_from_triplets(2, 2, indices, values, Default::default()).unwrap();
549 let v = M::V::from_vec(vec![f::<M>(7.0), f::<M>(8.0)], Default::default());
550 a.set_column(1, &v);
551 let idx = triplet_indices(&a);
552 let vals = triplet_values(&a);
553 assert_eq!(idx, vec![(0, 0), (1, 0), (0, 1), (1, 1)]);
554 assert_eq!(
555 vals,
556 vec![f::<M>(0.0), f::<M>(0.0), f::<M>(7.0), f::<M>(8.0)]
557 );
558 }
559
560 pub fn test_copy_from<M: Matrix>() {
561 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
562 let values = vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0)];
563 let a = M::try_from_triplets(2, 2, indices, values, Default::default()).unwrap();
564 let mut b = M::zeros(2, 2, Default::default());
565 b.copy_from(&a);
566 let vals = triplet_values(&b);
567 assert_eq!(
568 vals,
569 vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0)]
570 );
571 }
572
573 pub fn test_scale_add_and_assign<M: Matrix>() {
574 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
575 let x_vals = vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0)];
576 let y_vals = vec![f::<M>(10.0), f::<M>(20.0), f::<M>(30.0), f::<M>(40.0)];
577 let x = M::try_from_triplets(2, 2, indices.clone(), x_vals, Default::default()).unwrap();
578 let y = M::try_from_triplets(2, 2, indices, y_vals, Default::default()).unwrap();
579 let mut result = M::zeros(2, 2, Default::default());
580 result.copy_from(&x);
581 result.scale_add_and_assign(&x, f::<M>(2.0), &y);
582 let vals = triplet_values(&result);
583 assert_eq!(
584 vals,
585 vec![f::<M>(21.0), f::<M>(42.0), f::<M>(63.0), f::<M>(84.0)]
586 );
587 }
588
589 pub fn test_column_axpy<M: DenseMatrix>() {
592 let mut a = M::zeros(2, 2, Default::default());
593 a.set_index(0, 0, M::T::one());
594 a.set_index(0, 1, M::T::from_f64(2.0).unwrap());
595 a.set_index(1, 0, M::T::from_f64(3.0).unwrap());
596 a.set_index(1, 1, M::T::from_f64(4.0).unwrap());
597
598 a.column_axpy(M::T::from_f64(2.0).unwrap(), 0, 1);
599 assert_eq!(a.get_index(0, 0), M::T::one());
600 assert_eq!(a.get_index(0, 1), M::T::from_f64(4.0).unwrap());
601 assert_eq!(a.get_index(1, 0), M::T::from_f64(3.0).unwrap());
602 assert_eq!(a.get_index(1, 1), M::T::from_f64(10.0).unwrap());
603 }
604
605 pub fn test_resize_cols<M: DenseMatrix>() {
606 let mut a = M::zeros(2, 2, Default::default());
607 a.set_index(0, 0, M::T::one());
608 a.set_index(0, 1, M::T::from_f64(2.0).unwrap());
609 a.set_index(1, 0, M::T::from_f64(3.0).unwrap());
610 a.set_index(1, 1, M::T::from_f64(4.0).unwrap());
611
612 a.resize_cols(3);
613 assert_eq!(a.ncols(), 3);
614 assert_eq!(a.nrows(), 2);
615 assert_eq!(a.get_index(0, 0), M::T::one());
616 assert_eq!(a.get_index(0, 1), M::T::from_f64(2.0).unwrap());
617 assert_eq!(a.get_index(1, 0), M::T::from_f64(3.0).unwrap());
618 assert_eq!(a.get_index(1, 1), M::T::from_f64(4.0).unwrap());
619
620 a.set_index(0, 2, M::T::from_f64(5.0).unwrap());
621 a.set_index(1, 2, M::T::from_f64(6.0).unwrap());
622 assert_eq!(a.get_index(0, 2), M::T::from_f64(5.0).unwrap());
623 assert_eq!(a.get_index(1, 2), M::T::from_f64(6.0).unwrap());
624
625 a.resize_cols(2);
626 assert_eq!(a.ncols(), 2);
627 assert_eq!(a.nrows(), 2);
628 assert_eq!(a.get_index(0, 0), M::T::one());
629 assert_eq!(a.get_index(0, 1), M::T::from_f64(2.0).unwrap());
630 assert_eq!(a.get_index(1, 0), M::T::from_f64(3.0).unwrap());
631 assert_eq!(a.get_index(1, 1), M::T::from_f64(4.0).unwrap());
632 }
633
634 pub fn test_from_vec<M: DenseMatrix>() {
635 let a = M::from_vec(
636 2,
637 2,
638 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
639 Default::default(),
640 );
641 assert_eq!(a.nrows(), 2);
642 assert_eq!(a.ncols(), 2);
643 assert_eq!(a.get_index(0, 0), f::<M>(1.0));
644 assert_eq!(a.get_index(1, 0), f::<M>(3.0));
645 assert_eq!(a.get_index(0, 1), f::<M>(2.0));
646 assert_eq!(a.get_index(1, 1), f::<M>(4.0));
647 }
648
649 pub fn test_gemm<M: DenseMatrix>() {
650 let a = M::from_vec(
651 2,
652 2,
653 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
654 Default::default(),
655 );
656 let b = M::from_vec(
657 2,
658 2,
659 vec![f::<M>(2.0), f::<M>(1.0), f::<M>(0.0), f::<M>(3.0)],
660 Default::default(),
661 );
662 let mut c = M::zeros(2, 2, Default::default());
663 c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
664 assert_eq!(c.get_index(0, 0), f::<M>(4.0));
665 assert_eq!(c.get_index(1, 0), f::<M>(10.0));
666 assert_eq!(c.get_index(0, 1), f::<M>(6.0));
667 assert_eq!(c.get_index(1, 1), f::<M>(12.0));
668 }
669
670 pub fn test_mat_mul<M: DenseMatrix>() {
671 let a = M::from_vec(
672 2,
673 2,
674 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
675 Default::default(),
676 );
677 let b = M::from_vec(
678 2,
679 2,
680 vec![f::<M>(2.0), f::<M>(1.0), f::<M>(0.0), f::<M>(3.0)],
681 Default::default(),
682 );
683 let c = a.mat_mul(&b);
684 assert_eq!(c.get_index(0, 0), f::<M>(4.0));
685 assert_eq!(c.get_index(1, 0), f::<M>(10.0));
686 assert_eq!(c.get_index(0, 1), f::<M>(6.0));
687 assert_eq!(c.get_index(1, 1), f::<M>(12.0));
688 }
689
690 pub fn test_columns_view<M: DenseMatrix>() {
691 let a = M::from_vec(
692 2,
693 3,
694 vec![
695 f::<M>(1.0),
696 f::<M>(4.0),
697 f::<M>(2.0),
698 f::<M>(5.0),
699 f::<M>(3.0),
700 f::<M>(6.0),
701 ],
702 Default::default(),
703 );
704 let view = a.columns(0, 2);
705 assert_eq!(view.ncols(), 2);
706 assert_eq!(view.nrows(), 2);
707 let owned = view.into_owned();
708 assert_eq!(owned.get_index(0, 0), f::<M>(1.0));
709 assert_eq!(owned.get_index(1, 0), f::<M>(4.0));
710 assert_eq!(owned.get_index(0, 1), f::<M>(2.0));
711 assert_eq!(owned.get_index(1, 1), f::<M>(5.0));
712 }
713
714 pub fn test_column_view<M: DenseMatrix>() {
715 let a = M::from_vec(
716 2,
717 2,
718 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
719 Default::default(),
720 );
721 let col = a.column(1);
722 use crate::VectorView;
723 assert_eq!(col.get_index(0), f::<M>(2.0));
724 assert_eq!(col.get_index(1), f::<M>(4.0));
725 }
726
727 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
730 pub fn test_batched_zeros_m<M: Matrix>(ctx: M::C) {
731 assert_eq!(ctx.nbatch(), 2);
732 let a = M::zeros(2, 3, ctx);
733 assert_eq!(a.nrows(), 2);
734 assert_eq!(a.ncols(), 3);
735 let vals = triplet_values(&a);
736 assert!(vals.is_empty() || vals.iter().all(|v| v.is_zero()));
737 }
738
739 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
740 pub fn test_batched_gemv_m<M: Matrix>(ctx: M::C) {
741 assert_eq!(ctx.nbatch(), 2);
742 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
743 let values = vec![
744 f::<M>(1.0),
745 f::<M>(3.0),
746 f::<M>(2.0),
747 f::<M>(4.0), f::<M>(5.0),
749 f::<M>(7.0),
750 f::<M>(6.0),
751 f::<M>(8.0), ];
753 let a = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
754 let x = M::V::from_vec(
755 vec![f::<M>(1.0), f::<M>(2.0), f::<M>(1.0), f::<M>(1.0)],
756 ctx.clone(),
757 );
758 let mut y = M::V::zeros(2, ctx);
759 a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
760 assert_eq!(
761 y.clone_as_vec(),
762 vec![f::<M>(5.0), f::<M>(11.0), f::<M>(11.0), f::<M>(15.0)]
763 );
764 }
765
766 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
767 pub fn test_batched_gemv_broadcast_x_m<M: Matrix>(ctx: M::C) {
768 assert_eq!(ctx.nbatch(), 2);
769 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
770 let values = vec![
771 f::<M>(1.0),
772 f::<M>(3.0),
773 f::<M>(2.0),
774 f::<M>(4.0),
775 f::<M>(5.0),
776 f::<M>(7.0),
777 f::<M>(6.0),
778 f::<M>(8.0),
779 ];
780 let a = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
781 let x = M::V::from_vec(vec![f::<M>(1.0), f::<M>(2.0)], Default::default());
782 let mut y = M::V::zeros(2, ctx);
783 a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
784 assert_eq!(
785 y.clone_as_vec(),
786 vec![f::<M>(5.0), f::<M>(11.0), f::<M>(17.0), f::<M>(23.0)]
787 );
788 }
789
790 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
791 pub fn test_batched_gemv_broadcast_mat_m<M: Matrix>(ctx: M::C) {
792 assert_eq!(ctx.nbatch(), 2);
793 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
794 let values = vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)];
795 let a =
796 M::try_from_triplets(2, 2, indices, values, ctx.clone_with_nbatch(1).unwrap()).unwrap();
797 let x = M::V::from_vec(
798 vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0)],
799 ctx.clone(),
800 );
801 let mut y = M::V::zeros(2, ctx);
802 a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
803 assert_eq!(
804 y.clone_as_vec(),
805 vec![f::<M>(5.0), f::<M>(11.0), f::<M>(11.0), f::<M>(25.0)]
806 );
807 }
808
809 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
810 pub fn test_batched_from_diagonal_m<M: Matrix>(ctx: M::C) {
811 assert_eq!(ctx.nbatch(), 2);
812 let v = M::V::from_vec(
813 vec![f::<M>(2.0), f::<M>(3.0), f::<M>(4.0), f::<M>(5.0)],
814 ctx,
815 );
816 let a = M::from_diagonal(&v);
817 assert_eq!(a.nrows(), 2);
818 assert_eq!(a.ncols(), 2);
819 let idx = triplet_indices(&a);
820 let vals = triplet_values(&a);
821 for &(i, j) in &idx {
822 let pos = idx.iter().position(|&x| x == (i, j)).unwrap();
823 if i == j {
824 assert!(
825 vals[pos] != M::T::zero(),
826 "diagonal entry should be non-zero"
827 );
828 } else {
829 assert!(vals[pos].is_zero(), "off-diagonal entry should be zero");
830 }
831 }
832 }
833
834 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
835 pub fn test_batched_copy_from_m<M: Matrix>(ctx: M::C) {
836 assert_eq!(ctx.nbatch(), 2);
837 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
838 let values = vec![
839 f::<M>(1.0),
840 f::<M>(2.0),
841 f::<M>(3.0),
842 f::<M>(4.0),
843 f::<M>(5.0),
844 f::<M>(6.0),
845 f::<M>(7.0),
846 f::<M>(8.0),
847 ];
848 let a = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
849 let mut b = M::zeros(2, 2, ctx);
850 b.copy_from(&a);
851 let vals = triplet_values(&b);
852 assert_eq!(
853 vals,
854 vec![
855 f::<M>(1.0),
856 f::<M>(2.0),
857 f::<M>(3.0),
858 f::<M>(4.0),
859 f::<M>(5.0),
860 f::<M>(6.0),
861 f::<M>(7.0),
862 f::<M>(8.0),
863 ]
864 );
865 }
866
867 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
868 pub fn test_batched_set_column_m<M: Matrix>(ctx: M::C) {
869 assert_eq!(ctx.nbatch(), 2);
870 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
871 let values = vec![
872 f::<M>(0.0),
873 f::<M>(0.0),
874 f::<M>(0.0),
875 f::<M>(0.0),
876 f::<M>(0.0),
877 f::<M>(0.0),
878 f::<M>(0.0),
879 f::<M>(0.0),
880 ];
881 let mut a = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
882 let v = M::V::from_vec(
883 vec![f::<M>(5.0), f::<M>(6.0), f::<M>(7.0), f::<M>(8.0)],
884 ctx,
885 );
886 a.set_column(0, &v);
887 let vals = triplet_values(&a);
888 assert_eq!(
889 vals,
890 vec![
891 f::<M>(5.0),
892 f::<M>(6.0),
893 f::<M>(0.0),
894 f::<M>(0.0),
895 f::<M>(7.0),
896 f::<M>(8.0),
897 f::<M>(0.0),
898 f::<M>(0.0),
899 ]
900 );
901 }
902
903 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
904 pub fn test_batched_scale_add_and_assign_m<M: Matrix>(ctx: M::C) {
905 assert_eq!(ctx.nbatch(), 2);
906 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
907 let x_vals = vec![
908 f::<M>(1.0),
909 f::<M>(2.0),
910 f::<M>(3.0),
911 f::<M>(4.0),
912 f::<M>(5.0),
913 f::<M>(6.0),
914 f::<M>(7.0),
915 f::<M>(8.0),
916 ];
917 let y_vals = vec![
918 f::<M>(10.0),
919 f::<M>(20.0),
920 f::<M>(30.0),
921 f::<M>(40.0),
922 f::<M>(50.0),
923 f::<M>(60.0),
924 f::<M>(70.0),
925 f::<M>(80.0),
926 ];
927 let x = M::try_from_triplets(2, 2, indices.clone(), x_vals, ctx.clone()).unwrap();
928 let y = M::try_from_triplets(2, 2, indices, y_vals, ctx.clone()).unwrap();
929 let mut result = M::zeros(2, 2, ctx);
930 result.copy_from(&x);
931 result.scale_add_and_assign(&x, f::<M>(2.0), &y);
932 let vals = triplet_values(&result);
933 assert_eq!(
934 vals,
935 vec![
936 f::<M>(21.0),
937 f::<M>(42.0),
938 f::<M>(63.0),
939 f::<M>(84.0),
940 f::<M>(105.0),
941 f::<M>(126.0),
942 f::<M>(147.0),
943 f::<M>(168.0),
944 ]
945 );
946 }
947
948 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
951 pub fn test_batched_from_vec<M: DenseMatrix>(ctx: M::C) {
952 assert_eq!(ctx.nbatch(), 2);
953 let a = M::from_vec(
956 2,
957 2,
958 vec![
959 f::<M>(1.0),
960 f::<M>(3.0),
961 f::<M>(2.0),
962 f::<M>(4.0),
963 f::<M>(5.0),
964 f::<M>(7.0),
965 f::<M>(6.0),
966 f::<M>(8.0),
967 ],
968 ctx,
969 );
970 assert_eq!(a.nrows(), 2);
971 assert_eq!(a.ncols(), 2);
972 assert_eq!(a.get_index(0, 0), f::<M>(1.0));
973 assert_eq!(a.get_index(1, 0), f::<M>(3.0));
974 assert_eq!(a.get_index(0, 1), f::<M>(2.0));
975 assert_eq!(a.get_index(1, 1), f::<M>(4.0));
976 }
977
978 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
979 pub fn test_batched_gemm<M: DenseMatrix>(ctx: M::C) {
980 assert_eq!(ctx.nbatch(), 2);
981 let a = M::from_vec(
983 2,
984 2,
985 vec![
986 f::<M>(1.0),
987 f::<M>(0.0),
988 f::<M>(0.0),
989 f::<M>(1.0),
990 f::<M>(2.0),
991 f::<M>(0.0),
992 f::<M>(0.0),
993 f::<M>(2.0),
994 ],
995 ctx.clone(),
996 );
997 let b = M::from_vec(
999 2,
1000 2,
1001 vec![
1002 f::<M>(3.0),
1003 f::<M>(5.0),
1004 f::<M>(4.0),
1005 f::<M>(6.0),
1006 f::<M>(1.0),
1007 f::<M>(1.0),
1008 f::<M>(1.0),
1009 f::<M>(1.0),
1010 ],
1011 ctx.clone(),
1012 );
1013 let mut c = M::zeros(2, 2, ctx);
1014 c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
1015 assert_eq!(c.get_index(0, 0), f::<M>(3.0));
1017 assert_eq!(c.get_index(1, 0), f::<M>(5.0));
1018 assert_eq!(c.get_index(0, 1), f::<M>(4.0));
1019 assert_eq!(c.get_index(1, 1), f::<M>(6.0));
1020 }
1021
1022 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1023 pub fn test_batched_columns<M: DenseMatrix>(ctx: M::C) {
1024 assert_eq!(ctx.nbatch(), 2);
1025 let a = M::from_vec(
1028 2,
1029 3,
1030 vec![
1031 f::<M>(1.0),
1032 f::<M>(2.0),
1033 f::<M>(3.0),
1034 f::<M>(4.0),
1035 f::<M>(5.0),
1036 f::<M>(6.0),
1037 f::<M>(7.0),
1038 f::<M>(8.0),
1039 f::<M>(9.0),
1040 f::<M>(10.0),
1041 f::<M>(11.0),
1042 f::<M>(12.0),
1043 ],
1044 ctx.clone(),
1045 );
1046 let view = a.columns(0, 2);
1047 assert_eq!(view.ncols(), 2);
1048 assert_eq!(view.nrows(), 2);
1049 let owned = view.into_owned();
1050 assert_eq!(owned.nrows(), 2);
1051 assert_eq!(owned.ncols(), 2);
1052 let view2 = a.columns(0, 2);
1054 let x = M::V::from_vec(
1055 vec![f::<M>(1.0), f::<M>(1.0), f::<M>(1.0), f::<M>(1.0)],
1056 ctx.clone(),
1057 );
1058 let mut y = M::V::zeros(2, ctx);
1059 view2.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1060 assert_eq!(
1062 y.clone_as_vec(),
1063 vec![f::<M>(4.0), f::<M>(6.0), f::<M>(16.0), f::<M>(18.0)]
1064 );
1065 }
1066
1067 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1068 pub fn test_batched_gemv_o_on_columns<M: DenseMatrix>(ctx: M::C) {
1069 assert_eq!(ctx.nbatch(), 2);
1070 let diff = M::from_vec(
1073 2,
1074 3,
1075 vec![
1076 f::<M>(1.0),
1077 f::<M>(4.0),
1078 f::<M>(2.0),
1079 f::<M>(5.0),
1080 f::<M>(3.0),
1081 f::<M>(6.0),
1082 f::<M>(7.0),
1083 f::<M>(10.0),
1084 f::<M>(8.0),
1085 f::<M>(11.0),
1086 f::<M>(9.0),
1087 f::<M>(12.0),
1088 ],
1089 ctx.clone(),
1090 );
1091 let view = diff.columns(0, 2);
1093 let x = M::V::from_vec(
1096 vec![f::<M>(1.0), f::<M>(1.0), f::<M>(2.0), f::<M>(2.0)],
1097 ctx.clone(),
1098 );
1099 let mut y = M::V::zeros(2, ctx);
1100 view.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1101 assert_eq!(
1104 y.clone_as_vec(),
1105 vec![f::<M>(3.0), f::<M>(9.0), f::<M>(30.0), f::<M>(42.0)]
1106 );
1107 }
1108
1109 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1110 pub fn test_batched_gemv_v_broadcast_mat<M: DenseMatrix>(ctx3: M::C) {
1111 assert_eq!(ctx3.nbatch(), 2);
1112 let ctx1 = M::C::default();
1114 let diff = M::from_vec(
1116 2,
1117 3,
1118 vec![
1119 f::<M>(1.0),
1120 f::<M>(4.0),
1121 f::<M>(2.0),
1122 f::<M>(5.0),
1123 f::<M>(3.0),
1124 f::<M>(6.0),
1125 ],
1126 ctx1,
1127 );
1128 let view = diff.columns(0, 2);
1129 let x = M::V::from_vec(
1132 vec![f::<M>(1.0), f::<M>(1.0), f::<M>(2.0), f::<M>(2.0)],
1133 ctx3.clone(),
1134 );
1135 let mut y = M::V::zeros(2, ctx3);
1136 view.gemv_v(f::<M>(1.0), &x.as_view(), f::<M>(0.0), &mut y);
1137 assert_eq!(
1140 y.clone_as_vec(),
1141 vec![f::<M>(3.0), f::<M>(9.0), f::<M>(6.0), f::<M>(18.0)]
1142 );
1143 }
1144
1145 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1146 pub fn test_batched_gemv_o_broadcast_mat<M: DenseMatrix>(ctx3: M::C) {
1147 assert_eq!(ctx3.nbatch(), 2);
1148 let ctx1 = M::C::default();
1150 let diff = M::from_vec(
1152 2,
1153 3,
1154 vec![
1155 f::<M>(1.0),
1156 f::<M>(4.0),
1157 f::<M>(2.0),
1158 f::<M>(5.0),
1159 f::<M>(3.0),
1160 f::<M>(6.0),
1161 ],
1162 ctx1,
1163 );
1164 let view = diff.columns(0, 2);
1165 let x = M::V::from_vec(
1168 vec![f::<M>(1.0), f::<M>(1.0), f::<M>(2.0), f::<M>(2.0)],
1169 ctx3.clone(),
1170 );
1171 let mut y = M::V::zeros(2, ctx3);
1172 view.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1173 assert_eq!(
1176 y.clone_as_vec(),
1177 vec![f::<M>(3.0), f::<M>(9.0), f::<M>(6.0), f::<M>(18.0)]
1178 );
1179 }
1180
1181 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1182 pub fn test_batched_gemm_vo_on_columns<M: DenseMatrix>(ctx: M::C) {
1183 assert_eq!(ctx.nbatch(), 2);
1184 let diff = M::from_vec(
1186 2,
1187 3,
1188 vec![
1189 f::<M>(1.0),
1190 f::<M>(4.0),
1191 f::<M>(2.0),
1192 f::<M>(5.0),
1193 f::<M>(3.0),
1194 f::<M>(6.0),
1195 f::<M>(7.0),
1196 f::<M>(10.0),
1197 f::<M>(8.0),
1198 f::<M>(11.0),
1199 f::<M>(9.0),
1200 f::<M>(12.0),
1201 ],
1202 ctx.clone(),
1203 );
1204 let r = M::from_vec(
1206 2,
1207 2,
1208 vec![
1209 f::<M>(1.0),
1210 f::<M>(0.0),
1211 f::<M>(0.0),
1212 f::<M>(1.0),
1213 f::<M>(2.0),
1214 f::<M>(0.0),
1215 f::<M>(0.0),
1216 f::<M>(2.0),
1217 ],
1218 ctx.clone(),
1219 );
1220 let mut result = M::zeros(2, 3, ctx);
1221 {
1222 let d_view = diff.columns(0, 2);
1223 let mut r_view = result.columns_mut(0, 2);
1224 r_view.gemm_vo(f::<M>(1.0), &d_view, &r, f::<M>(0.0));
1225 }
1226 assert_eq!(result.get_index(0, 0), f::<M>(1.0));
1229 assert_eq!(result.get_index(1, 0), f::<M>(4.0));
1230 assert_eq!(result.get_index(0, 1), f::<M>(2.0));
1231 assert_eq!(result.get_index(1, 1), f::<M>(5.0));
1232 }
1233
1234 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1237 pub fn test_batched_gemm_broadcast_b<M: DenseMatrix>(ctx: M::C) {
1238 assert_eq!(ctx.nbatch(), 2);
1239 let a = M::from_vec(
1241 2,
1242 2,
1243 vec![
1244 f::<M>(1.0),
1245 f::<M>(0.0),
1246 f::<M>(0.0),
1247 f::<M>(1.0),
1248 f::<M>(2.0),
1249 f::<M>(0.0),
1250 f::<M>(0.0),
1251 f::<M>(3.0),
1252 ],
1253 ctx.clone(),
1254 );
1255 let b = M::from_vec(
1257 2,
1258 2,
1259 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1260 Default::default(),
1261 );
1262 let mut c = M::zeros(2, 2, ctx);
1263 c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
1264 assert_eq!(c.get_index(0, 0), f::<M>(1.0));
1266 assert_eq!(c.get_index(1, 0), f::<M>(3.0));
1267 assert_eq!(c.get_index(0, 1), f::<M>(2.0));
1268 assert_eq!(c.get_index(1, 1), f::<M>(4.0));
1269 }
1270
1271 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1272 pub fn test_batched_gemm_broadcast_a<M: DenseMatrix>(ctx: M::C) {
1273 assert_eq!(ctx.nbatch(), 2);
1274 let a = M::from_vec(
1276 2,
1277 2,
1278 vec![f::<M>(1.0), f::<M>(0.0), f::<M>(0.0), f::<M>(2.0)],
1279 Default::default(),
1280 );
1281 let b = M::from_vec(
1283 2,
1284 2,
1285 vec![
1286 f::<M>(3.0),
1287 f::<M>(5.0),
1288 f::<M>(4.0),
1289 f::<M>(6.0),
1290 f::<M>(1.0),
1291 f::<M>(1.0),
1292 f::<M>(1.0),
1293 f::<M>(1.0),
1294 ],
1295 ctx.clone(),
1296 );
1297 let mut c = M::zeros(2, 2, ctx);
1298 c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
1299 assert_eq!(c.get_index(0, 0), f::<M>(3.0));
1301 assert_eq!(c.get_index(1, 0), f::<M>(10.0));
1302 assert_eq!(c.get_index(0, 1), f::<M>(4.0));
1303 assert_eq!(c.get_index(1, 1), f::<M>(12.0));
1304 }
1305
1306 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1307 pub fn test_batched_gemv_o_broadcast_x<M: DenseMatrix>(ctx: M::C) {
1308 assert_eq!(ctx.nbatch(), 2);
1309 let diff = M::from_vec(
1311 2,
1312 3,
1313 vec![
1314 f::<M>(1.0),
1315 f::<M>(4.0),
1316 f::<M>(2.0),
1317 f::<M>(5.0),
1318 f::<M>(3.0),
1319 f::<M>(6.0),
1320 f::<M>(7.0),
1321 f::<M>(10.0),
1322 f::<M>(8.0),
1323 f::<M>(11.0),
1324 f::<M>(9.0),
1325 f::<M>(12.0),
1326 ],
1327 ctx.clone(),
1328 );
1329 let view = diff.columns(0, 2);
1330 let x = M::V::from_vec(vec![f::<M>(1.0), f::<M>(1.0)], Default::default());
1332 let mut y = M::V::zeros(2, ctx);
1333 view.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1334 assert_eq!(
1337 y.clone_as_vec(),
1338 vec![f::<M>(3.0), f::<M>(9.0), f::<M>(15.0), f::<M>(21.0)]
1339 );
1340 }
1341
1342 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1343 pub fn test_batched_gemm_vo_broadcast_b<M: DenseMatrix>(ctx: M::C) {
1344 assert_eq!(ctx.nbatch(), 2);
1345 let diff = M::from_vec(
1347 2,
1348 3,
1349 vec![
1350 f::<M>(1.0),
1351 f::<M>(4.0),
1352 f::<M>(2.0),
1353 f::<M>(5.0),
1354 f::<M>(3.0),
1355 f::<M>(6.0),
1356 f::<M>(7.0),
1357 f::<M>(10.0),
1358 f::<M>(8.0),
1359 f::<M>(11.0),
1360 f::<M>(9.0),
1361 f::<M>(12.0),
1362 ],
1363 ctx.clone(),
1364 );
1365 let r = M::from_vec(
1367 2,
1368 2,
1369 vec![f::<M>(1.0), f::<M>(0.0), f::<M>(0.0), f::<M>(1.0)],
1370 Default::default(),
1371 );
1372 let mut result = M::zeros(2, 3, ctx);
1373 {
1374 let d_view = diff.columns(0, 2);
1375 let mut r_view = result.columns_mut(0, 2);
1376 r_view.gemm_vo(f::<M>(1.0), &d_view, &r, f::<M>(0.0));
1377 }
1378 assert_eq!(result.get_index(0, 0), f::<M>(1.0));
1381 assert_eq!(result.get_index(1, 0), f::<M>(4.0));
1382 assert_eq!(result.get_index(0, 1), f::<M>(2.0));
1383 assert_eq!(result.get_index(1, 1), f::<M>(5.0));
1384 }
1385
1386 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1387 pub fn test_batched_gemm_vo_broadcast_a<M: DenseMatrix>(ctx: M::C) {
1388 assert_eq!(ctx.nbatch(), 2);
1389 let diff = M::from_vec(
1391 2,
1392 3,
1393 vec![
1394 f::<M>(1.0),
1395 f::<M>(4.0),
1396 f::<M>(2.0),
1397 f::<M>(5.0),
1398 f::<M>(3.0),
1399 f::<M>(6.0),
1400 ],
1401 Default::default(),
1402 );
1403 let b = M::from_vec(
1405 2,
1406 2,
1407 vec![
1408 f::<M>(1.0),
1409 f::<M>(0.0),
1410 f::<M>(0.0),
1411 f::<M>(1.0),
1412 f::<M>(2.0),
1413 f::<M>(0.0),
1414 f::<M>(0.0),
1415 f::<M>(3.0),
1416 ],
1417 ctx.clone(),
1418 );
1419 let mut result = M::zeros(2, 3, ctx);
1420 {
1421 let d_view = diff.columns(0, 2);
1422 let mut r_view = result.columns_mut(0, 2);
1423 r_view.gemm_vo(f::<M>(1.0), &d_view, &b, f::<M>(0.0));
1424 }
1425 assert_eq!(result.get_index(0, 0), f::<M>(1.0));
1427 assert_eq!(result.get_index(1, 0), f::<M>(4.0));
1428 assert_eq!(result.get_index(0, 1), f::<M>(2.0));
1429 assert_eq!(result.get_index(1, 1), f::<M>(5.0));
1430 }
1431
1432 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1435 pub fn test_batched_gemm_incompatible_a<M: DenseMatrix>(ctx2: M::C, ctx3: M::C) {
1436 assert_eq!(ctx2.nbatch(), 2);
1437 assert_eq!(ctx3.nbatch(), 3);
1438 let a = M::zeros(2, 2, ctx3);
1439 let b = M::zeros(2, 2, ctx2.clone());
1440 let mut c = M::zeros(2, 2, ctx2);
1441 c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
1442 }
1443
1444 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1445 pub fn test_batched_gemv_incompatible<M: DenseMatrix>(ctx2: M::C, ctx3: M::C) {
1446 assert_eq!(ctx2.nbatch(), 2);
1447 assert_eq!(ctx3.nbatch(), 3);
1448 let a = M::zeros(2, 2, ctx2.clone());
1449 let x = M::V::zeros(2, ctx3);
1450 let mut y = M::V::zeros(2, ctx2);
1451 a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1452 }
1453
1454 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1455 pub fn test_batched_gemm_incompatible<M: DenseMatrix>(ctx2: M::C, ctx3: M::C) {
1456 assert_eq!(ctx2.nbatch(), 2);
1457 assert_eq!(ctx3.nbatch(), 3);
1458 let a = M::zeros(2, 2, ctx2.clone());
1459 let b = M::zeros(2, 2, ctx3);
1460 let mut c = M::zeros(2, 2, ctx2);
1461 c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
1462 }
1463
1464 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1465 pub fn test_batched_resize_cols<M: DenseMatrix>(ctx: M::C) {
1466 assert_eq!(ctx.nbatch(), 2);
1467 let mut a = M::from_vec(
1469 2,
1470 2,
1471 vec![
1472 f::<M>(1.0),
1473 f::<M>(3.0),
1474 f::<M>(2.0),
1475 f::<M>(4.0),
1476 f::<M>(5.0),
1477 f::<M>(7.0),
1478 f::<M>(6.0),
1479 f::<M>(8.0),
1480 ],
1481 ctx.clone(),
1482 );
1483 a.resize_cols(3);
1485 assert_eq!(a.ncols(), 3);
1486 assert_eq!(a.nrows(), 2);
1487 assert_eq!(a.get_index(0, 0), f::<M>(1.0));
1489 assert_eq!(a.get_index(1, 0), f::<M>(3.0));
1490 assert_eq!(a.get_index(0, 1), f::<M>(2.0));
1491 assert_eq!(a.get_index(1, 1), f::<M>(4.0));
1492 assert_eq!(a.get_index(0, 2), f::<M>(0.0));
1494 assert_eq!(a.get_index(1, 2), f::<M>(0.0));
1495 let x = M::V::from_vec(
1497 vec![
1498 f::<M>(1.0),
1499 f::<M>(0.0),
1500 f::<M>(0.0),
1501 f::<M>(1.0),
1502 f::<M>(0.0),
1503 f::<M>(0.0),
1504 ],
1505 ctx.clone(),
1506 );
1507 let mut y = M::V::zeros(2, ctx.clone());
1508 a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1509 assert_eq!(
1512 y.clone_as_vec(),
1513 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(5.0), f::<M>(7.0)]
1514 );
1515
1516 a.resize_cols(1);
1518 assert_eq!(a.ncols(), 1);
1519 assert_eq!(a.get_index(0, 0), f::<M>(1.0));
1520 assert_eq!(a.get_index(1, 0), f::<M>(3.0));
1521 let x2 = M::V::from_vec(vec![f::<M>(1.0), f::<M>(1.0)], ctx.clone());
1523 let mut y2 = M::V::zeros(2, ctx);
1524 a.gemv(f::<M>(1.0), &x2, f::<M>(0.0), &mut y2);
1525 assert_eq!(
1527 y2.clone_as_vec(),
1528 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(5.0), f::<M>(7.0)]
1529 );
1530 }
1531
1532 pub fn test_mul_scalar<M: Matrix>() {
1535 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1536 let values = vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)];
1537 let a = M::try_from_triplets(2, 2, indices, values, Default::default()).unwrap();
1538 let result = a * Scale(f::<M>(2.0));
1539 let (_, vals) = result.triplet_iter();
1540 let vals: Vec<_> = vals.collect();
1541 assert_eq!(
1542 vals,
1543 vec![f::<M>(2.0), f::<M>(6.0), f::<M>(4.0), f::<M>(8.0)]
1544 );
1545 }
1546
1547 pub fn test_add_column_to_vector<M: Matrix>() {
1548 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1549 let values = vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0)];
1550 let mat = M::try_from_triplets(2, 2, indices, values, Default::default()).unwrap();
1551 let mut v = M::V::zeros(2, Default::default());
1552 mat.add_column_to_vector(1, &mut v);
1553 assert_eq!(v.clone_as_vec(), vec![f::<M>(3.0), f::<M>(4.0)]);
1554 }
1555
1556 pub fn test_add<M: DenseMatrix>() {
1559 let a = M::from_vec(
1560 2,
1561 2,
1562 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1563 Default::default(),
1564 );
1565 let b = M::from_vec(
1566 2,
1567 2,
1568 vec![f::<M>(5.0), f::<M>(7.0), f::<M>(6.0), f::<M>(8.0)],
1569 Default::default(),
1570 );
1571 let result = a + &b;
1572 assert_eq!(result.get_index(0, 0), f::<M>(6.0));
1573 assert_eq!(result.get_index(1, 1), f::<M>(12.0));
1574 }
1575
1576 pub fn test_sub<M: DenseMatrix>() {
1577 let a = M::from_vec(
1578 2,
1579 2,
1580 vec![f::<M>(5.0), f::<M>(7.0), f::<M>(6.0), f::<M>(8.0)],
1581 Default::default(),
1582 );
1583 let b = M::from_vec(
1584 2,
1585 2,
1586 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1587 Default::default(),
1588 );
1589 let result = a - &b;
1590 assert_eq!(result.get_index(0, 0), f::<M>(4.0));
1591 assert_eq!(result.get_index(1, 1), f::<M>(4.0));
1592 }
1593
1594 pub fn test_add_assign<M: DenseMatrix>() {
1595 let mut a = M::from_vec(
1596 2,
1597 2,
1598 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1599 Default::default(),
1600 );
1601 let b = M::from_vec(
1602 2,
1603 2,
1604 vec![f::<M>(5.0), f::<M>(7.0), f::<M>(6.0), f::<M>(8.0)],
1605 Default::default(),
1606 );
1607 a += &b;
1608 assert_eq!(a.get_index(0, 0), f::<M>(6.0));
1609 assert_eq!(a.get_index(1, 1), f::<M>(12.0));
1610 }
1611
1612 pub fn test_sub_assign<M: DenseMatrix>() {
1613 let mut a = M::from_vec(
1614 2,
1615 2,
1616 vec![f::<M>(5.0), f::<M>(7.0), f::<M>(6.0), f::<M>(8.0)],
1617 Default::default(),
1618 );
1619 let b = M::from_vec(
1620 2,
1621 2,
1622 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1623 Default::default(),
1624 );
1625 a -= &b;
1626 assert_eq!(a.get_index(0, 0), f::<M>(4.0));
1627 assert_eq!(a.get_index(1, 1), f::<M>(4.0));
1628 }
1629
1630 pub fn test_gather<M: DenseMatrix>() {
1631 let mat1 = M::from_vec(
1632 3,
1633 3,
1634 vec![
1635 f::<M>(1.0),
1636 f::<M>(2.0),
1637 f::<M>(3.0),
1638 f::<M>(4.0),
1639 f::<M>(5.0),
1640 f::<M>(6.0),
1641 f::<M>(7.0),
1642 f::<M>(8.0),
1643 f::<M>(9.0),
1644 ],
1645 Default::default(),
1646 );
1647 let mut mat2 = M::zeros(2, 2, Default::default());
1648 let indices = <M::V as Vector>::Index::from_vec(vec![0, 1, 3, 4], Default::default());
1649 mat2.gather(&mat1, &indices);
1650 assert_eq!(mat2.get_index(0, 0), f::<M>(1.0));
1651 assert_eq!(mat2.get_index(1, 0), f::<M>(2.0));
1652 assert_eq!(mat2.get_index(0, 1), f::<M>(4.0));
1653 assert_eq!(mat2.get_index(1, 1), f::<M>(5.0));
1654 }
1655
1656 pub fn test_set_data_with_indices<M: DenseMatrix>() {
1657 let mut mat = M::zeros(2, 2, Default::default());
1658 let dst_indices = <M::V as Vector>::Index::from_vec(vec![0, 3], Default::default());
1659 let src_indices = <M::V as Vector>::Index::from_vec(vec![0, 1], Default::default());
1660 let data = M::V::from_vec(vec![f::<M>(5.0), f::<M>(6.0)], Default::default());
1661 mat.set_data_with_indices(&dst_indices, &src_indices, &data);
1662 assert_eq!(mat.get_index(0, 0), f::<M>(5.0));
1663 assert_eq!(mat.get_index(1, 1), f::<M>(6.0));
1664 }
1665
1666 pub fn test_mul_assign_scalar<M: DenseMatrix>() {
1667 let mut mat = M::from_vec(
1668 2,
1669 2,
1670 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1671 Default::default(),
1672 );
1673 {
1674 let mut view = mat.columns_mut(0, 2);
1675 view *= Scale(f::<M>(2.0));
1676 }
1677 assert_eq!(mat.get_index(0, 0), f::<M>(2.0));
1678 assert_eq!(mat.get_index(1, 1), f::<M>(8.0));
1679 }
1680
1681 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1682 pub fn test_batched_combine<M: DenseMatrix>(ctx: M::C) {
1683 assert_eq!(ctx.nbatch(), 2);
1684 #[rustfmt::skip]
1685 let data: Vec<M::T> = vec![
1686 f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0),
1688 f::<M>(5.0), f::<M>(6.0), f::<M>(7.0), f::<M>(8.0),
1689 f::<M>(9.0), f::<M>(10.0), f::<M>(11.0), f::<M>(12.0),
1690 f::<M>(13.0), f::<M>(14.0), f::<M>(15.0), f::<M>(16.0),
1691 f::<M>(101.0), f::<M>(102.0), f::<M>(103.0), f::<M>(104.0),
1693 f::<M>(105.0), f::<M>(106.0), f::<M>(107.0), f::<M>(108.0),
1694 f::<M>(109.0), f::<M>(110.0), f::<M>(111.0), f::<M>(112.0),
1695 f::<M>(113.0), f::<M>(114.0), f::<M>(115.0), f::<M>(116.0),
1696 ];
1697 let m = M::from_vec(4, 4, data, ctx.clone());
1698
1699 let alg_indices = <M::V as Vector>::Index::from_vec(vec![1, 3], Default::default());
1700 let [(ul, _), (ur, _), (ll, _), (lr, _)] = m.split(&alg_indices);
1701
1702 let recombined = M::combine(&ul, &ur, &ll, &lr, &alg_indices);
1703
1704 let (_orig_idx, orig_vals) = m.triplet_iter();
1705 let (_recom_idx, recom_vals) = recombined.triplet_iter();
1706 let orig_vals: Vec<_> = orig_vals.collect();
1707 let recom_vals: Vec<_> = recom_vals.collect();
1708 assert_eq!(orig_vals, recom_vals);
1709 }
1710
1711 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1712 pub fn test_batched_add_column_to_vector_m<M: Matrix>(ctx: M::C) {
1713 assert_eq!(ctx.nbatch(), 2);
1714 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1715 let values = vec![
1716 f::<M>(1.0),
1717 f::<M>(2.0),
1718 f::<M>(3.0),
1719 f::<M>(4.0),
1720 f::<M>(5.0),
1721 f::<M>(6.0),
1722 f::<M>(7.0),
1723 f::<M>(8.0),
1724 ];
1725 let mat = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
1726 let mut v = M::V::zeros(2, ctx);
1727 mat.add_column_to_vector(1, &mut v);
1728 assert_eq!(
1729 v.clone_as_vec(),
1730 vec![f::<M>(3.0), f::<M>(4.0), f::<M>(7.0), f::<M>(8.0)]
1731 );
1732 }
1733
1734 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1735 pub fn test_batched_set_data_with_indices_m<M: Matrix>(ctx: M::C) {
1736 assert_eq!(ctx.nbatch(), 2);
1737 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1738 let zero_values = vec![
1739 f::<M>(0.0),
1740 f::<M>(0.0),
1741 f::<M>(0.0),
1742 f::<M>(0.0),
1743 f::<M>(0.0),
1744 f::<M>(0.0),
1745 f::<M>(0.0),
1746 f::<M>(0.0),
1747 ];
1748 let mut mat = M::try_from_triplets(2, 2, indices, zero_values, ctx.clone()).unwrap();
1749 let dst_indices = <M::V as Vector>::Index::from_vec(vec![0, 3], Default::default());
1750 let src_indices = <M::V as Vector>::Index::from_vec(vec![0, 1], Default::default());
1751 let data = M::V::from_vec(
1752 vec![f::<M>(5.0), f::<M>(6.0), f::<M>(50.0), f::<M>(60.0)],
1753 ctx,
1754 );
1755 mat.set_data_with_indices(&dst_indices, &src_indices, &data);
1756 let (_, vals) = mat.triplet_iter();
1757 let vals: Vec<_> = vals.collect();
1758 assert_eq!(
1759 vals,
1760 vec![
1761 f::<M>(5.0),
1762 f::<M>(0.0),
1763 f::<M>(0.0),
1764 f::<M>(6.0),
1765 f::<M>(50.0),
1766 f::<M>(0.0),
1767 f::<M>(0.0),
1768 f::<M>(60.0),
1769 ]
1770 );
1771 }
1772
1773 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1774 pub fn test_batched_gather_m<M: Matrix>(ctx: M::C) {
1775 assert_eq!(ctx.nbatch(), 2);
1776 let indices: Vec<(IndexType, IndexType)> =
1777 (0..3).flat_map(|j| (0..3).map(move |i| (i, j))).collect();
1778 let values = vec![
1779 f::<M>(1.0),
1780 f::<M>(2.0),
1781 f::<M>(3.0),
1782 f::<M>(4.0),
1783 f::<M>(5.0),
1784 f::<M>(6.0),
1785 f::<M>(7.0),
1786 f::<M>(8.0),
1787 f::<M>(9.0),
1788 f::<M>(10.0),
1789 f::<M>(20.0),
1790 f::<M>(30.0),
1791 f::<M>(40.0),
1792 f::<M>(50.0),
1793 f::<M>(60.0),
1794 f::<M>(70.0),
1795 f::<M>(80.0),
1796 f::<M>(90.0),
1797 ];
1798 let mat1 = M::try_from_triplets(3, 3, indices, values, ctx.clone()).unwrap();
1799 let dest_indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1800 let zero_values = vec![
1801 f::<M>(0.0),
1802 f::<M>(0.0),
1803 f::<M>(0.0),
1804 f::<M>(0.0),
1805 f::<M>(0.0),
1806 f::<M>(0.0),
1807 f::<M>(0.0),
1808 f::<M>(0.0),
1809 ];
1810 let mut mat2 = M::try_from_triplets(2, 2, dest_indices, zero_values, ctx).unwrap();
1811 let gather_indices =
1812 <M::V as Vector>::Index::from_vec(vec![0, 1, 3, 4], Default::default());
1813 mat2.gather(&mat1, &gather_indices);
1814 let (_, vals) = mat2.triplet_iter();
1815 let vals: Vec<_> = vals.collect();
1816 assert_eq!(
1817 vals,
1818 vec![
1819 f::<M>(1.0),
1820 f::<M>(2.0),
1821 f::<M>(4.0),
1822 f::<M>(5.0),
1823 f::<M>(10.0),
1824 f::<M>(20.0),
1825 f::<M>(40.0),
1826 f::<M>(50.0),
1827 ]
1828 );
1829 }
1830
1831 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1832 pub fn test_batched_mul_scalar_m<M: Matrix>(ctx: M::C) {
1833 assert_eq!(ctx.nbatch(), 2);
1834 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1835 let values = vec![
1836 f::<M>(1.0),
1837 f::<M>(3.0),
1838 f::<M>(2.0),
1839 f::<M>(4.0),
1840 f::<M>(5.0),
1841 f::<M>(7.0),
1842 f::<M>(6.0),
1843 f::<M>(8.0),
1844 ];
1845 let a = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
1846 let result = a * Scale(f::<M>(2.0));
1847 let (_, vals) = result.triplet_iter();
1848 let vals: Vec<_> = vals.collect();
1849 assert_eq!(
1850 vals,
1851 vec![
1852 f::<M>(2.0),
1853 f::<M>(6.0),
1854 f::<M>(4.0),
1855 f::<M>(8.0),
1856 f::<M>(10.0),
1857 f::<M>(14.0),
1858 f::<M>(12.0),
1859 f::<M>(16.0),
1860 ]
1861 );
1862 }
1863
1864 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1865 pub fn test_batched_partition_indices<M: Matrix>(ctx: M::C) {
1866 assert_eq!(ctx.nbatch(), 2);
1867 let zero_val = M::T::zero();
1868 let one_val = f::<M>(1.0);
1869 let two_val = f::<M>(2.0);
1870 let indices = vec![(0, 0), (1, 1), (2, 2)];
1871 let values = vec![one_val, zero_val, one_val, two_val, zero_val, two_val];
1872 let m = M::try_from_triplets(3, 3, indices, values, ctx).unwrap();
1873 let (zero_idx, nonzero_idx) = m.partition_indices_by_zero_diagonal();
1874 assert_eq!(zero_idx.clone_as_vec(), vec![1]);
1875 assert_eq!(nonzero_idx.clone_as_vec(), vec![0, 2]);
1876 }
1877
1878 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1879 pub fn test_batched_column_axpy<M: DenseMatrix>(ctx: M::C) {
1880 assert_eq!(ctx.nbatch(), 2);
1881 let mut a = M::from_vec(
1882 2,
1883 2,
1884 vec![
1885 f::<M>(1.0),
1886 f::<M>(3.0),
1887 f::<M>(2.0),
1888 f::<M>(4.0),
1889 f::<M>(5.0),
1890 f::<M>(7.0),
1891 f::<M>(6.0),
1892 f::<M>(8.0),
1893 ],
1894 ctx,
1895 );
1896 a.column_axpy(f::<M>(2.0), 0, 1);
1897 assert_eq!(a.get_index(0, 0), f::<M>(1.0));
1898 assert_eq!(a.get_index(0, 1), f::<M>(4.0));
1899 assert_eq!(a.get_index(1, 0), f::<M>(3.0));
1900 assert_eq!(a.get_index(1, 1), f::<M>(10.0));
1901 }
1902
1903 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1904 pub fn test_batched_mat_mul<M: DenseMatrix>(ctx: M::C) {
1905 assert_eq!(ctx.nbatch(), 2);
1906 let a = M::from_vec(
1907 2,
1908 2,
1909 vec![
1910 f::<M>(1.0),
1911 f::<M>(3.0),
1912 f::<M>(2.0),
1913 f::<M>(4.0),
1914 f::<M>(2.0),
1915 f::<M>(1.0),
1916 f::<M>(0.0),
1917 f::<M>(3.0),
1918 ],
1919 ctx.clone(),
1920 );
1921 let b = M::from_vec(
1922 2,
1923 2,
1924 vec![
1925 f::<M>(2.0),
1926 f::<M>(1.0),
1927 f::<M>(0.0),
1928 f::<M>(3.0),
1929 f::<M>(1.0),
1930 f::<M>(0.0),
1931 f::<M>(2.0),
1932 f::<M>(1.0),
1933 ],
1934 ctx.clone(),
1935 );
1936 let c = a.mat_mul(&b);
1937 assert_eq!(c.get_index(0, 0), f::<M>(4.0));
1938 assert_eq!(c.get_index(1, 0), f::<M>(10.0));
1939 assert_eq!(c.get_index(0, 1), f::<M>(6.0));
1940 assert_eq!(c.get_index(1, 1), f::<M>(12.0));
1941 }
1942
1943 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1944 pub fn test_batched_from_diagonal_dense<M: DenseMatrix>(ctx: M::C) {
1945 assert_eq!(ctx.nbatch(), 2);
1946 let v = M::V::from_vec(
1947 vec![f::<M>(2.0), f::<M>(3.0), f::<M>(4.0), f::<M>(5.0)],
1948 ctx,
1949 );
1950 let a = M::from_diagonal(&v);
1951 assert_eq!(a.nrows(), 2);
1952 assert_eq!(a.ncols(), 2);
1953 assert_eq!(a.get_index(0, 0), f::<M>(2.0));
1954 assert_eq!(a.get_index(1, 1), f::<M>(3.0));
1955 assert_eq!(a.get_index(0, 1), f::<M>(0.0));
1956 assert_eq!(a.get_index(1, 0), f::<M>(0.0));
1957 }
1958
1959 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1960 fn make_strided_matrix<M: DenseMatrix>(nbatch: usize) -> M {
1961 let ctx = M::C::default().clone_with_nbatch(nbatch).unwrap();
1962 let nrows = 3;
1963 let ncols = 4;
1964 let mut data = Vec::with_capacity(nrows * ncols * nbatch);
1965 for b in 0..nbatch {
1966 for col in 0..ncols {
1967 for row in 0..nrows {
1968 data.push(f::<M>(row as f64 + col as f64 * 10.0 + b as f64 * 100.0));
1969 }
1970 }
1971 }
1972 M::from_vec(nrows, ncols, data, ctx)
1973 }
1974
1975 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1976 pub fn test_strided_matrix_view_into_owned<M: DenseMatrix>(ctx: M::C) {
1977 let matrix = make_strided_matrix::<M>(ctx.nbatch());
1978 let view = matrix.columns(0, 2);
1979 let owned = view.into_owned();
1980 assert_eq!(owned.nrows(), 3);
1981 assert_eq!(owned.ncols(), 2);
1982 assert_eq!(owned.get_index(0, 0), f::<M>(0.0));
1984 assert_eq!(owned.get_index(1, 0), f::<M>(1.0));
1985 assert_eq!(owned.get_index(2, 0), f::<M>(2.0));
1986 assert_eq!(owned.get_index(0, 1), f::<M>(10.0));
1988 assert_eq!(owned.get_index(1, 1), f::<M>(11.0));
1989 assert_eq!(owned.get_index(2, 1), f::<M>(12.0));
1990 }
1991
1992 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1993 pub fn test_strided_matrix_view_add_owned<M: DenseMatrix>(ctx: M::C) {
1994 let matrix = make_strided_matrix::<M>(ctx.nbatch());
1995 let view = matrix.columns(0, 2);
1996 let rhs = M::from_vec(
1998 3,
1999 2,
2000 vec![
2001 f::<M>(1.0),
2002 f::<M>(2.0),
2003 f::<M>(3.0),
2004 f::<M>(4.0),
2005 f::<M>(5.0),
2006 f::<M>(6.0),
2007 ],
2008 M::C::default(),
2009 );
2010 let result = view + &rhs;
2011 assert_eq!(result.get_index(0, 0), f::<M>(1.0));
2013 assert_eq!(result.get_index(0, 1), f::<M>(14.0));
2014 }
2015
2016 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2017 pub fn test_strided_matrix_view_sub_owned<M: DenseMatrix>(ctx: M::C) {
2018 let matrix = make_strided_matrix::<M>(ctx.nbatch());
2019 let view = matrix.columns(0, 2);
2020 let rhs = M::from_vec(
2021 3,
2022 2,
2023 vec![
2024 f::<M>(0.0),
2025 f::<M>(1.0),
2026 f::<M>(2.0),
2027 f::<M>(10.0),
2028 f::<M>(11.0),
2029 f::<M>(12.0),
2030 ],
2031 M::C::default(),
2032 );
2033 let result = view - &rhs;
2034 assert_eq!(result.get_index(0, 0), f::<M>(0.0));
2036 assert_eq!(result.get_index(0, 1), f::<M>(0.0));
2037 }
2038
2039 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2040 pub fn test_strided_matrix_view_mul_scalar<M: DenseMatrix>(ctx: M::C) {
2041 let matrix = make_strided_matrix::<M>(ctx.nbatch());
2042 let view = matrix.columns(0, 2);
2043 let result = view * Scale(f::<M>(2.0));
2044 assert_eq!(result.get_index(0, 0), f::<M>(0.0));
2045 assert_eq!(result.get_index(1, 0), f::<M>(2.0));
2046 assert_eq!(result.get_index(0, 1), f::<M>(20.0));
2047 }
2048
2049 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2050 pub fn test_strided_matrix_view_mut_add_assign_view<M: DenseMatrix>(ctx: M::C) {
2051 let mut a = make_strided_matrix::<M>(ctx.nbatch());
2052 let b = make_strided_matrix::<M>(ctx.nbatch());
2053 {
2054 let mut a_view = a.columns_mut(0, 2);
2055 let b_view = b.columns(2, 4);
2056 a_view += &b_view;
2057 }
2058 assert_eq!(a.get_index(0, 0), f::<M>(20.0));
2063 assert_eq!(a.get_index(1, 0), f::<M>(22.0));
2064 assert_eq!(a.get_index(0, 1), f::<M>(40.0));
2065 }
2066
2067 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2068 pub fn test_strided_matrix_view_mut_sub_assign_view<M: DenseMatrix>(ctx: M::C) {
2069 let mut a = make_strided_matrix::<M>(ctx.nbatch());
2070 let b = make_strided_matrix::<M>(ctx.nbatch());
2071 {
2072 let mut a_view = a.columns_mut(0, 2);
2073 let b_view = b.columns(0, 2);
2074 a_view -= &b_view;
2075 }
2076 assert_eq!(a.get_index(0, 0), f::<M>(0.0));
2078 assert_eq!(a.get_index(1, 0), f::<M>(0.0));
2079 }
2080
2081 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2082 pub fn test_strided_matrix_view_mut_mul_assign_scalar<M: DenseMatrix>(ctx: M::C) {
2083 let mut a = make_strided_matrix::<M>(ctx.nbatch());
2084 {
2085 let mut a_view = a.columns_mut(0, 2);
2086 a_view *= Scale(f::<M>(2.0));
2087 }
2088 assert_eq!(a.get_index(0, 0), f::<M>(0.0));
2089 assert_eq!(a.get_index(1, 0), f::<M>(2.0));
2090 assert_eq!(a.get_index(0, 1), f::<M>(20.0));
2091 }
2092
2093 pub fn test_view_mut_into_owned<M: DenseMatrix>() {
2096 let mut a = M::from_vec(
2097 2,
2098 3,
2099 vec![
2100 f::<M>(1.0),
2101 f::<M>(2.0),
2102 f::<M>(3.0),
2103 f::<M>(4.0),
2104 f::<M>(5.0),
2105 f::<M>(6.0),
2106 ],
2107 Default::default(),
2108 );
2109 let owned = a.columns_mut(0, 2).into_owned();
2110 assert_eq!(owned.nrows(), 2);
2111 assert_eq!(owned.ncols(), 2);
2112 assert_eq!(owned.get_index(0, 0), f::<M>(1.0));
2113 assert_eq!(owned.get_index(1, 0), f::<M>(2.0));
2114 assert_eq!(owned.get_index(0, 1), f::<M>(3.0));
2115 assert_eq!(owned.get_index(1, 1), f::<M>(4.0));
2116 }
2117
2118 pub fn test_view_mut_add_assign_view_mut<M: DenseMatrix>() {
2119 let mut a = M::from_vec(
2120 2,
2121 2,
2122 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
2123 Default::default(),
2124 );
2125 let mut b = M::from_vec(
2126 2,
2127 2,
2128 vec![f::<M>(10.0), f::<M>(30.0), f::<M>(20.0), f::<M>(40.0)],
2129 Default::default(),
2130 );
2131 {
2132 let mut a_view = a.columns_mut(0, 2);
2133 let b_view = b.columns_mut(0, 2);
2134 a_view += &b_view;
2135 }
2136 assert_eq!(a.get_index(0, 0), f::<M>(11.0));
2137 assert_eq!(a.get_index(1, 0), f::<M>(33.0));
2138 assert_eq!(a.get_index(0, 1), f::<M>(22.0));
2139 assert_eq!(a.get_index(1, 1), f::<M>(44.0));
2140 }
2141
2142 pub fn test_view_mut_sub_assign_view_mut<M: DenseMatrix>() {
2143 let mut a = M::from_vec(
2144 2,
2145 2,
2146 vec![f::<M>(10.0), f::<M>(30.0), f::<M>(20.0), f::<M>(40.0)],
2147 Default::default(),
2148 );
2149 let mut b = M::from_vec(
2150 2,
2151 2,
2152 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
2153 Default::default(),
2154 );
2155 {
2156 let mut a_view = a.columns_mut(0, 2);
2157 let b_view = b.columns_mut(0, 2);
2158 a_view -= &b_view;
2159 }
2160 assert_eq!(a.get_index(0, 0), f::<M>(9.0));
2161 assert_eq!(a.get_index(1, 0), f::<M>(27.0));
2162 assert_eq!(a.get_index(0, 1), f::<M>(18.0));
2163 assert_eq!(a.get_index(1, 1), f::<M>(36.0));
2164 }
2165
2166 pub fn test_gemm_oo_on_columns<M: DenseMatrix>() {
2167 let a = M::from_vec(
2169 2,
2170 2,
2171 vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
2172 Default::default(),
2173 );
2174 let b = M::from_vec(
2176 2,
2177 2,
2178 vec![f::<M>(1.0), f::<M>(0.0), f::<M>(0.0), f::<M>(1.0)],
2179 Default::default(),
2180 );
2181 let mut result = M::zeros(2, 3, Default::default());
2182 {
2183 let mut r_view = result.columns_mut(0, 2);
2184 r_view.gemm_oo(f::<M>(1.0), &a, &b, f::<M>(0.0));
2185 }
2186 assert_eq!(result.get_index(0, 0), f::<M>(1.0));
2188 assert_eq!(result.get_index(1, 0), f::<M>(3.0));
2189 assert_eq!(result.get_index(0, 1), f::<M>(2.0));
2190 assert_eq!(result.get_index(1, 1), f::<M>(4.0));
2191 assert_eq!(result.get_index(0, 2), f::<M>(0.0));
2192 assert_eq!(result.get_index(1, 2), f::<M>(0.0));
2193 }
2194
2195 pub fn test_try_from_triplets_wrong_length<M: Matrix>() {
2196 let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
2197 let values = vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0)];
2199 let _ = M::try_from_triplets(2, 2, indices, values, Default::default());
2200 }
2201
2202 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2205 pub fn test_strided_matrix_view_mut_into_owned<M: DenseMatrix>(ctx: M::C) {
2206 let mut matrix = make_strided_matrix::<M>(ctx.nbatch());
2207 let owned = matrix.columns_mut(0, 2).into_owned();
2208 assert_eq!(owned.nrows(), 3);
2209 assert_eq!(owned.ncols(), 2);
2210 assert_eq!(owned.get_index(0, 0), f::<M>(0.0));
2212 assert_eq!(owned.get_index(1, 0), f::<M>(1.0));
2213 assert_eq!(owned.get_index(2, 0), f::<M>(2.0));
2214 assert_eq!(owned.get_index(0, 1), f::<M>(10.0));
2215 assert_eq!(owned.get_index(1, 1), f::<M>(11.0));
2216 assert_eq!(owned.get_index(2, 1), f::<M>(12.0));
2217 let (_, vals) = owned.triplet_iter();
2219 let vals: Vec<_> = vals.collect();
2220 assert_eq!(
2221 vals,
2222 vec![
2223 f::<M>(0.0),
2224 f::<M>(1.0),
2225 f::<M>(2.0),
2226 f::<M>(10.0),
2227 f::<M>(11.0),
2228 f::<M>(12.0),
2229 f::<M>(100.0),
2230 f::<M>(101.0),
2231 f::<M>(102.0),
2232 f::<M>(110.0),
2233 f::<M>(111.0),
2234 f::<M>(112.0),
2235 ]
2236 );
2237 }
2238
2239 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2240 pub fn test_batched_view_mut_add_assign_view_mut<M: DenseMatrix>(ctx: M::C) {
2241 assert_eq!(ctx.nbatch(), 2);
2242 let mut a = M::from_vec(
2244 2,
2245 2,
2246 vec![
2247 f::<M>(1.0),
2248 f::<M>(3.0),
2249 f::<M>(2.0),
2250 f::<M>(4.0),
2251 f::<M>(5.0),
2252 f::<M>(7.0),
2253 f::<M>(6.0),
2254 f::<M>(8.0),
2255 ],
2256 ctx.clone(),
2257 );
2258 let mut b = M::from_vec(
2259 2,
2260 2,
2261 vec![
2262 f::<M>(10.0),
2263 f::<M>(30.0),
2264 f::<M>(20.0),
2265 f::<M>(40.0),
2266 f::<M>(50.0),
2267 f::<M>(70.0),
2268 f::<M>(60.0),
2269 f::<M>(80.0),
2270 ],
2271 ctx,
2272 );
2273 {
2274 let mut a_view = a.columns_mut(0, 2);
2275 let b_view = b.columns_mut(0, 2);
2276 a_view += &b_view;
2277 }
2278 let (_, vals) = a.triplet_iter();
2279 let vals: Vec<_> = vals.collect();
2280 assert_eq!(
2281 vals,
2282 vec![
2283 f::<M>(11.0),
2284 f::<M>(33.0),
2285 f::<M>(22.0),
2286 f::<M>(44.0),
2287 f::<M>(55.0),
2288 f::<M>(77.0),
2289 f::<M>(66.0),
2290 f::<M>(88.0),
2291 ]
2292 );
2293 }
2294
2295 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2296 pub fn test_batched_view_mut_sub_assign_view_mut<M: DenseMatrix>(ctx: M::C) {
2297 assert_eq!(ctx.nbatch(), 2);
2298 let mut a = M::from_vec(
2299 2,
2300 2,
2301 vec![
2302 f::<M>(10.0),
2303 f::<M>(30.0),
2304 f::<M>(20.0),
2305 f::<M>(40.0),
2306 f::<M>(50.0),
2307 f::<M>(70.0),
2308 f::<M>(60.0),
2309 f::<M>(80.0),
2310 ],
2311 ctx.clone(),
2312 );
2313 let mut b = M::from_vec(
2314 2,
2315 2,
2316 vec![
2317 f::<M>(1.0),
2318 f::<M>(3.0),
2319 f::<M>(2.0),
2320 f::<M>(4.0),
2321 f::<M>(5.0),
2322 f::<M>(7.0),
2323 f::<M>(6.0),
2324 f::<M>(8.0),
2325 ],
2326 ctx,
2327 );
2328 {
2329 let mut a_view = a.columns_mut(0, 2);
2330 let b_view = b.columns_mut(0, 2);
2331 a_view -= &b_view;
2332 }
2333 let (_, vals) = a.triplet_iter();
2334 let vals: Vec<_> = vals.collect();
2335 assert_eq!(
2336 vals,
2337 vec![
2338 f::<M>(9.0),
2339 f::<M>(27.0),
2340 f::<M>(18.0),
2341 f::<M>(36.0),
2342 f::<M>(45.0),
2343 f::<M>(63.0),
2344 f::<M>(54.0),
2345 f::<M>(72.0),
2346 ]
2347 );
2348 }
2349
2350 #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2351 pub fn test_batched_gemm_oo_on_columns<M: DenseMatrix>(ctx: M::C) {
2352 assert_eq!(ctx.nbatch(), 2);
2353 let a = M::from_vec(
2355 2,
2356 2,
2357 vec![
2358 f::<M>(1.0),
2359 f::<M>(3.0),
2360 f::<M>(2.0),
2361 f::<M>(4.0),
2362 f::<M>(5.0),
2363 f::<M>(7.0),
2364 f::<M>(6.0),
2365 f::<M>(8.0),
2366 ],
2367 ctx.clone(),
2368 );
2369 let b = M::from_vec(
2371 2,
2372 2,
2373 vec![
2374 f::<M>(1.0),
2375 f::<M>(0.0),
2376 f::<M>(0.0),
2377 f::<M>(1.0),
2378 f::<M>(2.0),
2379 f::<M>(0.0),
2380 f::<M>(0.0),
2381 f::<M>(2.0),
2382 ],
2383 ctx.clone(),
2384 );
2385 let mut result = M::zeros(2, 3, ctx);
2386 {
2387 let mut r_view = result.columns_mut(0, 2);
2388 r_view.gemm_oo(f::<M>(1.0), &a, &b, f::<M>(0.0));
2389 }
2390 assert_eq!(result.get_index(0, 0), f::<M>(1.0));
2392 assert_eq!(result.get_index(1, 0), f::<M>(3.0));
2393 assert_eq!(result.get_index(0, 1), f::<M>(2.0));
2394 assert_eq!(result.get_index(1, 1), f::<M>(4.0));
2395 let (_, vals) = result.triplet_iter();
2396 let vals: Vec<_> = vals.collect();
2397 assert_eq!(
2398 vals,
2399 vec![
2400 f::<M>(1.0),
2401 f::<M>(3.0),
2402 f::<M>(2.0),
2403 f::<M>(4.0),
2404 f::<M>(0.0),
2405 f::<M>(0.0),
2406 f::<M>(10.0),
2407 f::<M>(14.0),
2408 f::<M>(12.0),
2409 f::<M>(16.0),
2410 f::<M>(0.0),
2411 f::<M>(0.0),
2412 ]
2413 );
2414 }
2415}
2416
/// Expands to one `#[test]` function per generic, context-free test in `matrix::tests`,
/// instantiated for the concrete matrix type `$M`.
///
/// Each generated test is named `<test name>_<$suffix>` (built with `paste`), so invoking the
/// macro several times with different suffixes does not produce clashing test names.
#[cfg(test)]
macro_rules! generate_matrix_tests_nonbatched {
    ($suffix:ident, $M:ty) => {
        paste::paste! {
            #[test]
            fn [<test_zeros_ $suffix>]() {
                $crate::matrix::tests::test_zeros::<$M>();
            }
            #[test]
            fn [<test_from_diagonal_ $suffix>]() {
                $crate::matrix::tests::test_from_diagonal::<$M>();
            }
            #[test]
            fn [<test_gemv_ $suffix>]() {
                $crate::matrix::tests::test_gemv::<$M>();
            }
            #[test]
            fn [<test_set_column_ $suffix>]() {
                $crate::matrix::tests::test_set_column::<$M>();
            }
            #[test]
            fn [<test_copy_from_ $suffix>]() {
                $crate::matrix::tests::test_copy_from::<$M>();
            }
            #[test]
            fn [<test_scale_add_and_assign_ $suffix>]() {
                $crate::matrix::tests::test_scale_add_and_assign::<$M>();
            }
            // Generated name is shorter than the function it calls.
            #[test]
            fn [<test_partition_indices_ $suffix>]() {
                $crate::matrix::tests::test_partition_indices_by_zero_diagonal::<$M>();
            }
            #[test]
            fn [<test_mul_scalar_ $suffix>]() {
                $crate::matrix::tests::test_mul_scalar::<$M>();
            }
            #[test]
            fn [<test_add_column_to_vector_ $suffix>]() {
                $crate::matrix::tests::test_add_column_to_vector::<$M>();
            }
            // The called test builds triplets whose index and value counts differ; this
            // generated test passes only if that call panics.
            #[test]
            #[should_panic]
            fn [<test_try_from_triplets_wrong_length_ $suffix>]() {
                $crate::matrix::tests::test_try_from_triplets_wrong_length::<$M>();
            }
        }
    };
}
2465
/// Expands to one `#[test]` function per batched generic test in `matrix::tests`, instantiated
/// for the concrete matrix type `$M` and passed the context expression `$ctx2`.
///
/// `$ctx2` is substituted into each generated test separately, so the expression is evaluated
/// once per test. The visible test functions it calls assert `ctx.nbatch() == 2`, so `$ctx2`
/// should build a two-batch context. Generated names are suffixed with `$suffix` via `paste`.
///
/// NOTE(review): `$ctx1` is accepted but not used by any test generated here — confirm whether
/// it is reserved for future tests or can be dropped from call sites. The `unused_macros`
/// allowance suggests the macro is only invoked when the `cuda` feature is enabled; verify.
#[cfg(test)]
#[cfg_attr(not(feature = "cuda"), allow(unused_macros))]
macro_rules! generate_matrix_tests_batched {
    ($suffix:ident, $M:ty, $ctx1:expr, $ctx2:expr) => {
        paste::paste! {
            #[test]
            fn [<test_batched_add_column_to_vector_ $suffix>]() {
                $crate::matrix::tests::test_batched_add_column_to_vector_m::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_set_data_with_indices_ $suffix>]() {
                $crate::matrix::tests::test_batched_set_data_with_indices_m::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_gather_ $suffix>]() {
                $crate::matrix::tests::test_batched_gather_m::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_mul_scalar_ $suffix>]() {
                $crate::matrix::tests::test_batched_mul_scalar_m::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_partition_indices_ $suffix>]() {
                $crate::matrix::tests::test_batched_partition_indices::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_zeros_ $suffix>]() {
                $crate::matrix::tests::test_batched_zeros_m::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_gemv_ $suffix>]() {
                $crate::matrix::tests::test_batched_gemv_m::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_gemv_broadcast_x_ $suffix>]() {
                $crate::matrix::tests::test_batched_gemv_broadcast_x_m::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_gemv_broadcast_mat_ $suffix>]() {
                $crate::matrix::tests::test_batched_gemv_broadcast_mat_m::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_from_diagonal_ $suffix>]() {
                $crate::matrix::tests::test_batched_from_diagonal_m::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_copy_from_ $suffix>]() {
                $crate::matrix::tests::test_batched_copy_from_m::<$M>($ctx2);
            }
            #[test]
            fn [<test_batched_set_column_ $suffix>]() {
                $crate::matrix::tests::test_batched_set_column_m::<$M>($ctx2);
            }
            // Generated name drops the `_and_assign` part of the called function's name.
            #[test]
            fn [<test_batched_scale_add_ $suffix>]() {
                $crate::matrix::tests::test_batched_scale_add_and_assign_m::<$M>($ctx2);
            }
        }
    };
}
2526
/// Generates one `#[test]` function per dense, non-batched matrix check in
/// `crate::matrix::tests`, each instantiated for the matrix type `$Mat`.
///
/// * `$tag` is appended to every generated test name.
/// * `$Mat` is the dense matrix type under test.
///
/// The checks take no arguments.
#[cfg(test)]
macro_rules! generate_dense_matrix_tests_nonbatched {
    // Internal arm. For every `prefix => check` pair, emit a test named `<prefix><tag>`
    // that calls `crate::matrix::tests::<check>::<$Mat>()`. Each prefix already ends
    // in `_`.
    (@emit $tag:ident, $Mat:ty; $($prefix:ident => $check:ident),* $(,)?) => {
        paste::paste! {
            $(
                #[test]
                fn [<$prefix $tag>]() {
                    $crate::matrix::tests::$check::<$Mat>();
                }
            )*
        }
    };
    // Public entry point.
    ($tag:ident, $Mat:ty) => {
        $crate::matrix::generate_dense_matrix_tests_nonbatched!(@emit $tag, $Mat;
            test_from_vec_ => test_from_vec,
            test_from_diagonal_dense_ => test_from_diagonal_dense,
            test_gemm_ => test_gemm,
            test_mat_mul_ => test_mat_mul,
            test_columns_view_ => test_columns_view,
            test_column_view_ => test_column_view,
            test_column_axpy_ => test_column_axpy,
            test_resize_cols_ => test_resize_cols,
            test_add_ => test_add,
            test_sub_ => test_sub,
            test_add_assign_ => test_add_assign,
            test_sub_assign_ => test_sub_assign,
            test_gather_ => test_gather,
            test_set_data_with_indices_ => test_set_data_with_indices,
            test_mul_assign_scalar_ => test_mul_assign_scalar,
            test_view_mut_into_owned_ => test_view_mut_into_owned,
            test_view_mut_add_assign_view_mut_ => test_view_mut_add_assign_view_mut,
            test_view_mut_sub_assign_view_mut_ => test_view_mut_sub_assign_view_mut,
            test_gemm_oo_on_columns_ => test_gemm_oo_on_columns,
        );
    };
}
2610
/// Generates one `#[test]` function per dense, batched matrix check in
/// `crate::matrix::tests`, each instantiated for the matrix type `$Mat`.
///
/// * `$tag` is appended to every generated test name.
/// * `$Mat` is the dense matrix type under test.
/// * `$base_ctx` is used only by the "incompatible" tests. Each of those calls
///   `.clone_with_nbatch(3).unwrap()` on it to build a second context.
/// * `$batched_ctx` is the context expression passed to every check. It is pasted
///   into, and therefore evaluated separately by, each generated test.
///
/// The incompatible tests are marked
/// `#[should_panic(expected = "incompatible nbatch")]`.
#[cfg(test)]
#[cfg_attr(not(feature = "cuda"), allow(unused_macros))]
macro_rules! generate_dense_matrix_tests_batched {
    // Internal arm for ordinary checks. Each emits `<prefix><tag>` calling
    // `crate::matrix::tests::<check>::<$Mat>(ctx)`. Prefixes already end in `_`.
    (@emit $tag:ident, $Mat:ty, $ctx:expr; $($prefix:ident => $check:ident),* $(,)?) => {
        paste::paste! {
            $(
                #[test]
                fn [<$prefix $tag>]() {
                    $crate::matrix::tests::$check::<$Mat>($ctx);
                }
            )*
        }
    };
    // Internal arm for checks that must panic. Each passes the batched context plus a
    // copy of `other` re-batched to 3.
    (@emit_incompatible $tag:ident, $Mat:ty, $ctx:expr, $other:expr;
        $($prefix:ident => $check:ident),* $(,)?) => {
        paste::paste! {
            $(
                #[test]
                #[should_panic(expected = "incompatible nbatch")]
                fn [<$prefix $tag>]() {
                    $crate::matrix::tests::$check::<$Mat>($ctx, $other.clone_with_nbatch(3).unwrap());
                }
            )*
        }
    };
    // Public entry point.
    ($tag:ident, $Mat:ty, $base_ctx:expr, $batched_ctx:expr) => {
        $crate::matrix::generate_dense_matrix_tests_batched!(@emit $tag, $Mat, $batched_ctx;
            test_batched_column_axpy_ => test_batched_column_axpy,
            test_batched_mat_mul_ => test_batched_mat_mul,
            test_batched_from_diagonal_dense_ => test_batched_from_diagonal_dense,
            test_batched_from_vec_ => test_batched_from_vec,
            test_batched_gemm_ => test_batched_gemm,
            test_batched_columns_ => test_batched_columns,
            test_batched_gemv_o_on_columns_ => test_batched_gemv_o_on_columns,
            test_batched_gemm_vo_on_columns_ => test_batched_gemm_vo_on_columns,
            test_batched_gemm_broadcast_b_ => test_batched_gemm_broadcast_b,
            test_batched_gemv_o_broadcast_x_ => test_batched_gemv_o_broadcast_x,
            test_batched_gemv_v_broadcast_mat_ => test_batched_gemv_v_broadcast_mat,
            test_batched_gemv_o_broadcast_mat_ => test_batched_gemv_o_broadcast_mat,
            test_batched_gemm_vo_broadcast_b_ => test_batched_gemm_vo_broadcast_b,
            test_batched_gemm_broadcast_a_ => test_batched_gemm_broadcast_a,
            test_batched_gemm_vo_broadcast_a_ => test_batched_gemm_vo_broadcast_a,
            test_batched_resize_cols_ => test_batched_resize_cols,
            test_batched_combine_ => test_batched_combine,
            test_strided_matrix_view_into_owned_ => test_strided_matrix_view_into_owned,
            test_strided_matrix_view_add_owned_ => test_strided_matrix_view_add_owned,
            test_strided_matrix_view_sub_owned_ => test_strided_matrix_view_sub_owned,
            test_strided_matrix_view_mul_scalar_ => test_strided_matrix_view_mul_scalar,
            test_strided_matrix_view_mut_add_assign_view_ => test_strided_matrix_view_mut_add_assign_view,
            test_strided_matrix_view_mut_sub_assign_view_ => test_strided_matrix_view_mut_sub_assign_view,
            test_strided_matrix_view_mut_mul_assign_scalar_ => test_strided_matrix_view_mut_mul_assign_scalar,
            test_strided_matrix_view_mut_into_owned_ => test_strided_matrix_view_mut_into_owned,
            test_batched_view_mut_add_assign_view_mut_ => test_batched_view_mut_add_assign_view_mut,
            test_batched_view_mut_sub_assign_view_mut_ => test_batched_view_mut_sub_assign_view_mut,
            test_batched_gemm_oo_on_columns_ => test_batched_gemm_oo_on_columns,
        );
        $crate::matrix::generate_dense_matrix_tests_batched!(@emit_incompatible $tag, $Mat, $batched_ctx, $base_ctx;
            test_batched_gemv_incompatible_ => test_batched_gemv_incompatible,
            test_batched_gemm_incompatible_ => test_batched_gemm_incompatible,
            test_batched_gemm_incompatible_a_ => test_batched_gemm_incompatible_a,
        );
    };
}
2746
// Re-export the test-generating macros so other modules in the crate can invoke them
// by path (e.g. `crate::matrix::generate_matrix_tests_nonbatched!`).

// Non-batched generators.
#[cfg(test)]
pub(crate) use generate_matrix_tests_nonbatched;
#[cfg(test)]
pub(crate) use generate_dense_matrix_tests_nonbatched;

// Batched generators. The `unused_imports` allowance suggests they are only
// invoked when the `cuda` feature is enabled.
#[cfg(test)]
#[cfg_attr(not(feature = "cuda"), allow(unused_imports))]
pub(crate) use generate_matrix_tests_batched;
#[cfg(test)]
#[cfg_attr(not(feature = "cuda"), allow(unused_imports))]
pub(crate) use generate_dense_matrix_tests_batched;