use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
use crate::error::DiffsolError;
use crate::scalar::Scale;
use crate::vector::VectorHost;
use crate::{Context, IndexType, Scalar, Vector, VectorIndex};
use extract_block::combine;
use num_traits::{One, Zero};
use sparsity::{Dense, MatrixSparsity, MatrixSparsityRef};
// Backend modules: each one is compiled only when the named Cargo feature is enabled.
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(feature = "nalgebra")]
pub mod dense_nalgebra_serial;
#[cfg(feature = "faer")]
pub mod dense_faer_serial;
#[cfg(feature = "faer")]
pub mod sparse_faer;
// Submodules that are compiled unconditionally.
pub mod default_solver;
pub mod extract_block;
pub mod sparsity;
// `#[macro_use]` makes the macros declared in `utils` usable by code that follows this line.
#[macro_use]
mod utils;
/// Base trait shared by owned matrices, matrix views and references to matrices.
///
/// It fixes the associated scalar, vector, context and inner types and exposes the
/// matrix dimensions.
pub trait MatrixCommon: Sized + Debug {
/// Vector type paired with this matrix; it uses the same scalar type `T` and context `C`.
type V: Vector<T = Self::T, C = Self::C, Index: VectorIndex<C = Self::C>>;
/// Scalar element type.
type T: Scalar;
/// Context type carried by matrices and vectors.
type C: Context;
/// Type handed out by [`MatrixCommon::inner`]; presumably the wrapped backend matrix.
type Inner;
/// Number of rows.
fn nrows(&self) -> IndexType;
/// Number of columns.
fn ncols(&self) -> IndexType;
/// Borrows the `Self::Inner` value behind this matrix.
fn inner(&self) -> &Self::Inner;
}
/// Forwards [`MatrixCommon`] through a shared reference: every associated type and
/// method of `&M` is the one of `M`.
impl<M: MatrixCommon> MatrixCommon for &M {
    type V = M::V;
    type T = M::T;
    type C = M::C;
    type Inner = M::Inner;

    fn inner(&self) -> &Self::Inner {
        // `**self` is the referenced `M`; the call auto-borrows it again.
        (**self).inner()
    }

    fn nrows(&self) -> IndexType {
        (**self).nrows()
    }

    fn ncols(&self) -> IndexType {
        (**self).ncols()
    }
}
/// Forwards [`MatrixCommon`] through a mutable reference: every associated type and
/// method of `&mut M` is the one of `M`.
impl<M: MatrixCommon> MatrixCommon for &mut M {
    type V = M::V;
    type T = M::T;
    type C = M::C;
    type Inner = M::Inner;

    fn inner(&self) -> &Self::Inner {
        // Only shared access is needed, so `**self` is re-borrowed immutably.
        (**self).inner()
    }

    fn nrows(&self) -> IndexType {
        (**self).nrows()
    }

    fn ncols(&self) -> IndexType {
        (**self).ncols()
    }
}
/// Shorthand bound for matrix types that support by-value `+` and `-` with a right-hand
/// side of type `Rhs`, both producing `Output`.
pub trait MatrixOpsByValue<Rhs = Self, Output = Self>:
MatrixCommon + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
{
}
/// Blanket implementation: any type meeting the supertrait bounds is a `MatrixOpsByValue`.
impl<M, Rhs, Output> MatrixOpsByValue<Rhs, Output> for M where
M: MatrixCommon + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
{
}
/// Shorthand bound for matrix types that support in-place `+=` and `-=` with a
/// right-hand side of type `Rhs`.
pub trait MatrixMutOpsByValue<Rhs = Self>: MatrixCommon + AddAssign<Rhs> + SubAssign<Rhs> {}
/// Blanket implementation: any type meeting the supertrait bounds is a `MatrixMutOpsByValue`.
impl<M, Rhs> MatrixMutOpsByValue<Rhs> for M where M: MatrixCommon + AddAssign<Rhs> + SubAssign<Rhs> {}
/// Bound for reference-like types that can be multiplied by `Scale<M::T>` to produce an
/// owned `M`.
pub trait MatrixRef<M: MatrixCommon>: Mul<Scale<M::T>, Output = M> {}
/// Blanket implementation: any type with the required `Mul` implementation is a `MatrixRef<M>`.
impl<RefT, M: MatrixCommon> MatrixRef<M> for RefT where RefT: Mul<Scale<M::T>, Output = M> {}
/// Mutable view into a matrix.
///
/// A mutable view supports `+=` / `-=` with references to itself and to the matching
/// immutable view type, and `*=` with a `Scale<T>`.
pub trait MatrixViewMut<'a>:
for<'b> MatrixMutOpsByValue<&'b Self>
+ for<'b> MatrixMutOpsByValue<&'b Self::View>
+ MulAssign<Scale<Self::T>>
{
/// Owned matrix type produced by [`MatrixViewMut::into_owned`].
type Owned;
/// Immutable view type that can be combined with this mutable view.
type View;
/// Consumes the view and returns a value of the owned matrix type.
fn into_owned(self) -> Self::Owned;
/// Matrix-matrix product with two owned operands.
/// Presumably `self = alpha * a * b + beta * self` (BLAS `gemm` convention) - TODO confirm.
fn gemm_oo(&mut self, alpha: Self::T, a: &Self::Owned, b: &Self::Owned, beta: Self::T);
/// Matrix-matrix product with a view as left operand and an owned right operand.
/// Presumably `self = alpha * a * b + beta * self`; the tests in this file call it with
/// `alpha = 1`, `beta = 0` and expect `a * b`.
fn gemm_vo(&mut self, alpha: Self::T, a: &Self::View, b: &Self::Owned, beta: Self::T);
}
/// Immutable view into a matrix.
///
/// A view can be added to or subtracted from a reference to the owned type, and scaled
/// by a `Scale<T>`; each of these operations yields an owned matrix.
pub trait MatrixView<'a>:
for<'b> MatrixOpsByValue<&'b Self::Owned, Self::Owned> + Mul<Scale<Self::T>, Output = Self::Owned>
{
/// Owned matrix type produced by [`MatrixView::into_owned`] and by the arithmetic operators.
type Owned;
/// Consumes the view and returns a value of the owned matrix type.
fn into_owned(self) -> Self::Owned;
/// Matrix-vector product where `x` is a vector view.
/// Presumably `y = alpha * self * x + beta * y`; the tests in this file use
/// `alpha = 1`, `beta = 0` and expect `self * x`.
fn gemv_v(
&self,
alpha: Self::T,
x: &<Self::V as Vector>::View<'_>,
beta: Self::T,
y: &mut Self::V,
);
/// Matrix-vector product where `x` is an owned vector.
/// Presumably `y = alpha * self * x + beta * y`; the tests in this file use
/// `alpha = 1`, `beta = 0` and expect `self * x`.
fn gemv_o(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
}
/// Owned matrix abstraction implemented by the backends of this module.
///
/// Besides [`MatrixCommon`], an implementor must be scalable by `Scale<T>`, `Clone`,
/// `Send` and `'static`.
pub trait Matrix:
MatrixCommon + Mul<Scale<Self::T>, Output = Self> + Clone + Send + 'static
{
/// Sparsity description accepted by [`Matrix::new_from_sparsity`].
type Sparsity: MatrixSparsity<Self>;
/// Borrowed sparsity description returned by [`Matrix::sparsity`].
type SparsityRef<'a>: MatrixSparsityRef<'a, Self>
where
Self: 'a;
/// Returns the sparsity of this matrix, or `None`; `is_sparse` and `split` below treat
/// `None` as the dense case.
fn sparsity(&self) -> Option<Self::SparsityRef<'_>>;
/// Borrows the context this matrix holds.
fn context(&self) -> &Self::C;
/// Mutably borrows the `Self::Inner` value behind this matrix.
fn inner_mut(&mut self) -> &mut Self::Inner;
/// Reports whether this matrix type is sparse by building a 1x1 zero matrix with a
/// default context and checking whether it exposes a sparsity.
fn is_sparse() -> bool {
Self::zeros(1, 1, Default::default()).sparsity().is_some()
}
/// Returns `(zero_diagonal_indices, non_zero_diagonal_indices)`.
/// The tests in this file expect an index whose diagonal entry is absent or stored as
/// zero to appear in the first set.
fn partition_indices_by_zero_diagonal(
&self,
) -> (<Self::V as Vector>::Index, <Self::V as Vector>::Index);
/// Matrix-vector product. Presumably `y = alpha * self * x + beta * y`; the tests in
/// this file use `alpha = 1`, `beta = 0` and expect `self * x`.
fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
/// Copies `other` into `self` (the tests expect the stored values to match afterwards).
fn copy_from(&mut self, other: &Self);
/// Creates an `nrows` x `ncols` matrix of zeros using context `ctx`.
fn zeros(nrows: IndexType, ncols: IndexType, ctx: Self::C) -> Self;
/// Creates an `nrows` x `ncols` matrix from an optional sparsity; `split` passes `None`
/// for dense matrices.
fn new_from_sparsity(
nrows: IndexType,
ncols: IndexType,
sparsity: Option<Self::Sparsity>,
ctx: Self::C,
) -> Self;
/// Creates a square matrix whose diagonal is taken from `v`.
fn from_diagonal(v: &Self::V) -> Self;
/// Writes `v` into column `j`.
fn set_column(&mut self, j: IndexType, v: &Self::V);
/// Adds column `j` of this matrix to `v` (the tests start from a zero `v` and expect
/// it to equal the column afterwards).
fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V);
/// Presumably copies `data[src_indices[k]]` into the stored value at `dst_indices[k]`
/// - TODO confirm against the implementations.
fn set_data_with_indices(
&mut self,
dst_indices: &<Self::V as Vector>::Index,
src_indices: &<Self::V as Vector>::Index,
data: &Self::V,
);
/// Fills `self` from `other` using `indices`; `split` uses this with the source
/// indices produced by the sparsity split.
fn gather(&mut self, other: &Self, indices: &<Self::V as Vector>::Index);
/// Splits this matrix into four blocks according to `algebraic_indices`.
///
/// Each block is returned together with the source indices that were used to gather
/// its values from `self`.
fn split(
&self,
algebraic_indices: &<Self::V as Vector>::Index,
) -> [(Self, <Self::V as Vector>::Index); 4] {
match self.sparsity() {
// Sparse case: each block is allocated with its own split sparsity.
Some(sp) => sp.split(algebraic_indices).map(|(sp, src_indices)| {
let mut m = Self::new_from_sparsity(
sp.nrows(),
sp.ncols(),
Some(sp),
self.context().clone(),
);
m.gather(self, &src_indices);
(m, src_indices)
}),
// Dense case: a `Dense` sparsity of the full shape provides the block shapes and
// source indices, and the blocks themselves are allocated without a sparsity.
None => Dense::<Self>::new(self.nrows(), self.ncols())
.split(algebraic_indices)
.map(|(sp, src_indices)| {
let mut m = Self::new_from_sparsity(
sp.nrows(),
sp.ncols(),
None,
self.context().clone(),
);
m.gather(self, &src_indices);
(m, src_indices)
}),
}
}
/// Assembles a matrix from four blocks by delegating to `extract_block::combine`.
fn combine(
ul: &Self,
ur: &Self,
ll: &Self,
lr: &Self,
algebraic_indices: &<Self::V as Vector>::Index,
) -> Self {
combine(ul, ur, ll, lr, algebraic_indices)
}
/// The tests in this file expect `self = x + beta * y` after this call.
fn scale_add_and_assign(&mut self, x: &Self, beta: Self::T, y: &Self);
/// Returns a pair of iterators: the `(row, column)` coordinates of the stored entries
/// and their values.
fn triplet_iter(
&self,
) -> (
impl Iterator<Item = (IndexType, IndexType)> + '_,
impl Iterator<Item = Self::T> + '_,
);
/// Builds an `nrows` x `ncols` matrix from `(row, column)` coordinates and values.
///
/// The batched tests in this file pass the values of every batch one after another
/// for a single list of coordinates. Returns a `DiffsolError` on failure.
fn try_from_triplets(
nrows: IndexType,
ncols: IndexType,
indices: Vec<(IndexType, IndexType)>,
values: Vec<Self::T>,
ctx: Self::C,
) -> Result<Self, DiffsolError>;
}
/// Marker trait for matrices whose vector type implements `VectorHost`.
pub trait MatrixHost: Matrix<V: VectorHost> {}
/// Blanket implementation: every such `Matrix` is a `MatrixHost`.
impl<T: Matrix<V: VectorHost>> MatrixHost for T {}
/// Dense matrix abstraction.
///
/// In addition to [`Matrix`], a dense matrix supports `+`, `-`, `+=` and `-=` with
/// references to itself and to its immutable views.
pub trait DenseMatrix:
Matrix
+ for<'b> MatrixOpsByValue<&'b Self, Self>
+ for<'b> MatrixMutOpsByValue<&'b Self>
+ for<'a, 'b> MatrixOpsByValue<&'b Self::View<'a>, Self>
+ for<'a, 'b> MatrixMutOpsByValue<&'b Self::View<'a>>
{
/// Immutable view type returned by [`DenseMatrix::columns`].
type View<'a>: MatrixView<'a, Owned = Self, T = Self::T, V = Self::V>
where
Self: 'a;
/// Mutable view type returned by [`DenseMatrix::columns_mut`].
type ViewMut<'a>: MatrixViewMut<
'a,
Owned = Self,
T = Self::T,
V = Self::V,
View = Self::View<'a>,
>
where
Self: 'a;
/// Matrix-matrix product. Presumably `self = alpha * a * b + beta * self`; `mat_mul`
/// below calls it with `alpha = 1`, `beta = 0` on a zero matrix to obtain `a * b`.
fn gemm(&mut self, alpha: Self::T, a: &Self, b: &Self, beta: Self::T);
/// The tests in this file expect column `i` to become `column i + alpha * column j`.
fn column_axpy(&mut self, alpha: Self::T, j: IndexType, i: IndexType);
/// Immutable view of a range of columns; the tests expect `columns(0, 2)` to contain
/// two columns, i.e. `end` is exclusive.
fn columns(&self, start: IndexType, end: IndexType) -> Self::View<'_>;
/// Immutable vector view of column `i`.
fn column(&self, i: IndexType) -> <Self::V as Vector>::View<'_>;
/// Mutable view of a range of columns (same arguments as [`DenseMatrix::columns`]).
fn columns_mut(&mut self, start: IndexType, end: IndexType) -> Self::ViewMut<'_>;
/// Mutable vector view of column `i`.
fn column_mut(&mut self, i: IndexType) -> <Self::V as Vector>::ViewMut<'_>;
/// Stores `value` at row `i`, column `j`.
fn set_index(&mut self, i: IndexType, j: IndexType, value: Self::T);
/// Reads the value at row `i`, column `j`.
fn get_index(&self, i: IndexType, j: IndexType) -> Self::T;
/// Returns `self * b` as a new `self.nrows()` x `b.ncols()` matrix that shares a clone
/// of `self`'s context.
fn mat_mul(&self, b: &Self) -> Self {
let nrows = self.nrows();
let ncols = b.ncols();
let mut ret = Self::zeros(nrows, ncols, self.context().clone());
// alpha = 1 and beta = 0, so the zero-initialised `ret` receives the plain product.
ret.gemm(Self::T::one(), self, b, Self::T::zero());
ret
}
/// Changes the number of columns to `ncols`; the tests expect the columns that remain
/// to keep their values.
fn resize_cols(&mut self, ncols: IndexType);
/// Builds an `nrows` x `ncols` matrix from `data`; the tests in this file supply the
/// entries column by column.
fn from_vec(nrows: IndexType, ncols: IndexType, data: Vec<Self::T>, ctx: Self::C) -> Self;
}
#[cfg(test)]
pub(crate) mod tests {
use super::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut};
use crate::scalar::Scale;
use crate::{scalar::IndexType, Context, Vector, VectorIndex};
use num_traits::{FromPrimitive, One, Zero};
/// Converts an `f64` into the scalar type `M::T` of the matrix type `M`.
///
/// Panics (through `unwrap`) if `from_f64` returns `None` for `x`.
fn f<M>(x: f64) -> M::T
where
    M: Matrix,
{
    let converted = M::T::from_f64(x);
    converted.unwrap()
}
/// Collects the values yielded by the second iterator of `m.triplet_iter()`.
fn triplet_values<M>(m: &M) -> Vec<M::T>
where
    M: Matrix,
{
    let values = m.triplet_iter().1;
    values.collect()
}
/// Collects the coordinates yielded by the first iterator of `m.triplet_iter()`.
fn triplet_indices<M>(m: &M) -> Vec<(IndexType, IndexType)>
where
    M: Matrix,
{
    let coordinates = m.triplet_iter().0;
    coordinates.collect()
}
/// Exercises `partition_indices_by_zero_diagonal` on three 4x4 diagonal matrices:
/// one without an entry at (2, 2), one with an explicit zero there, and one whose
/// diagonal is entirely non-zero.
pub fn test_partition_indices_by_zero_diagonal<M: Matrix>() {
    // Builds the matrix from triplets, partitions it and compares both index sets.
    let check = |indices: Vec<(IndexType, IndexType)>,
                 values: Vec<M::T>,
                 expected_zero: Vec<IndexType>,
                 expected_non_zero: Vec<IndexType>| {
        let m = M::try_from_triplets(4, 4, indices, values, Default::default()).unwrap();
        let (zero_diag, non_zero_diag) = m.partition_indices_by_zero_diagonal();
        assert_eq!(zero_diag.clone_as_vec(), expected_zero);
        assert_eq!(non_zero_diag.clone_as_vec(), expected_non_zero);
    };
    let scalar = |x: f64| M::T::from_f64(x).unwrap();

    // Diagonal entry (2, 2) is not stored at all.
    check(
        vec![(0, 0), (1, 1), (3, 3)],
        vec![M::T::one(), scalar(2.0), M::T::one()],
        vec![2],
        vec![0, 1, 3],
    );

    // Diagonal entry (2, 2) is stored, but its value is zero.
    check(
        vec![(0, 0), (1, 1), (2, 2), (3, 3)],
        vec![M::T::one(), scalar(2.0), M::T::zero(), M::T::one()],
        vec![2],
        vec![0, 1, 3],
    );

    // Every diagonal entry is stored and non-zero.
    check(
        vec![(0, 0), (1, 1), (2, 2), (3, 3)],
        vec![M::T::one(), scalar(2.0), scalar(3.0), M::T::one()],
        Vec::<IndexType>::new(),
        vec![0, 1, 2, 3],
    );
}
/// Checks that `M::zeros(2, 3, ..)` has shape 2x3 and stores no non-zero value
/// (a sparse backend may store no values at all).
pub fn test_zeros<M: Matrix>() {
    let zero_matrix = M::zeros(2, 3, M::C::default());
    assert_eq!(zero_matrix.nrows(), 2);
    assert_eq!(zero_matrix.ncols(), 3);
    let stored = triplet_values(&zero_matrix);
    let all_zero = stored.iter().all(|value| value.is_zero());
    assert!(stored.is_empty() || all_zero);
}
/// Checks `M::from_diagonal` for the diagonal `[2, 3, 5]`.
///
/// The result must be 3x3, every stored diagonal entry must be non-zero, every stored
/// off-diagonal entry must be zero, and all three diagonal entries must be stored.
pub fn test_from_diagonal<M: Matrix>() {
    let v = M::V::from_vec(
        vec![f::<M>(2.0), f::<M>(3.0), f::<M>(5.0)],
        Default::default(),
    );
    let a = M::from_diagonal(&v);
    assert_eq!(a.nrows(), 3);
    assert_eq!(a.ncols(), 3);
    let idx = triplet_indices(&a);
    let vals = triplet_values(&a);
    // Walk coordinates and values in lockstep instead of searching `idx` for the
    // element that is already being visited.
    let mut diagonal_entries = 0;
    for (&(i, j), val) in idx.iter().zip(vals.iter()) {
        if i == j {
            assert!(*val != M::T::zero(), "diagonal entry should be non-zero");
            diagonal_entries += 1;
        } else {
            assert!(val.is_zero(), "off-diagonal entry should be zero");
        }
    }
    // Without this check the test would pass for a matrix that stores nothing.
    assert_eq!(
        diagonal_entries, 3,
        "all three diagonal entries should be stored"
    );
}
/// Checks `M::from_diagonal` on a dense backend: the diagonal `[2, 3, 5]` must appear
/// at (0, 0), (1, 1), (2, 2) and the sampled off-diagonal entries must be zero.
pub fn test_from_diagonal_dense<M: DenseMatrix>() {
    let diagonal = [2.0, 3.0, 5.0];
    let v = M::V::from_vec(
        diagonal.iter().map(|&d| f::<M>(d)).collect(),
        M::C::default(),
    );
    let a = M::from_diagonal(&v);
    assert_eq!(a.nrows(), 3);
    assert_eq!(a.ncols(), 3);
    // (row, column, expected value)
    let expectations = [
        (0, 0, 2.0),
        (1, 1, 3.0),
        (2, 2, 5.0),
        (0, 1, 0.0),
        (1, 0, 0.0),
    ];
    for &(row, col, expected) in expectations.iter() {
        assert_eq!(a.get_index(row, col), f::<M>(expected));
    }
}
/// Checks `gemv` with `alpha = 1`, `beta = 0`: `[[1, 2], [3, 4]] * [1, 2] = [5, 11]`.
pub fn test_gemv<M: Matrix>() {
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    // Coordinates of a full 2x2 matrix, listed column by column.
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let a = M::try_from_triplets(
        2,
        2,
        coords,
        scalars(&[1.0, 3.0, 2.0, 4.0]),
        M::C::default(),
    )
    .unwrap();
    let x = M::V::from_vec(scalars(&[1.0, 2.0]), M::C::default());
    let mut y = M::V::zeros(2, M::C::default());
    a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
    assert_eq!(y.clone_as_vec(), scalars(&[5.0, 11.0]));
}
/// Checks that `set_column(1, [7, 8])` on an all-zero 2x2 matrix changes only column 1
/// and leaves the stored coordinates unchanged.
pub fn test_set_column<M: Matrix>() {
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    // Coordinates of a full 2x2 matrix, listed column by column.
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let mut a = M::try_from_triplets(
        2,
        2,
        coords.clone(),
        scalars(&[0.0, 0.0, 0.0, 0.0]),
        M::C::default(),
    )
    .unwrap();
    let column = M::V::from_vec(scalars(&[7.0, 8.0]), M::C::default());
    a.set_column(1, &column);
    let stored_coords = triplet_indices(&a);
    let stored_values = triplet_values(&a);
    assert_eq!(stored_coords, coords);
    assert_eq!(stored_values, scalars(&[0.0, 0.0, 7.0, 8.0]));
}
/// Checks that `copy_from` transfers the values of a full 2x2 matrix into a matrix
/// created with `zeros`.
pub fn test_copy_from<M: Matrix>() {
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let source_values = [1.0, 2.0, 3.0, 4.0];
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let source =
        M::try_from_triplets(2, 2, coords, scalars(&source_values), M::C::default()).unwrap();
    let mut target = M::zeros(2, 2, M::C::default());
    target.copy_from(&source);
    assert_eq!(triplet_values(&target), scalars(&source_values));
}
/// Checks `scale_add_and_assign(x, 2, y)`: with `x = [1, 2, 3, 4]` and
/// `y = [10, 20, 30, 40]` the stored values must become `x + 2 * y`.
pub fn test_scale_add_and_assign<M: Matrix>() {
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let x = M::try_from_triplets(
        2,
        2,
        coords.clone(),
        scalars(&[1.0, 2.0, 3.0, 4.0]),
        M::C::default(),
    )
    .unwrap();
    let y = M::try_from_triplets(
        2,
        2,
        coords,
        scalars(&[10.0, 20.0, 30.0, 40.0]),
        M::C::default(),
    )
    .unwrap();
    let mut result = M::zeros(2, 2, M::C::default());
    result.copy_from(&x);
    result.scale_add_and_assign(&x, f::<M>(2.0), &y);
    assert_eq!(
        triplet_values(&result),
        scalars(&[21.0, 42.0, 63.0, 84.0])
    );
}
/// Checks `column_axpy(2, 0, 1)` on `[[1, 2], [3, 4]]`: column 1 must become
/// `column 1 + 2 * column 0 = [4, 10]` while column 0 stays untouched.
pub fn test_column_axpy<M: DenseMatrix>() {
    let scalar = |x: f64| M::T::from_f64(x).unwrap();
    let mut a = M::zeros(2, 2, M::C::default());
    a.set_index(0, 0, M::T::one());
    for &(row, col, value) in [(0, 1, 2.0), (1, 0, 3.0), (1, 1, 4.0)].iter() {
        a.set_index(row, col, scalar(value));
    }
    a.column_axpy(scalar(2.0), 0, 1);
    assert_eq!(a.get_index(0, 0), M::T::one());
    for &(row, col, expected) in [(0, 1, 4.0), (1, 0, 3.0), (1, 1, 10.0)].iter() {
        assert_eq!(a.get_index(row, col), scalar(expected));
    }
}
/// Checks `resize_cols`: growing a 2x2 matrix to three columns and shrinking it back
/// must keep the original 2x2 entries, and the added column must be writable.
pub fn test_resize_cols<M: DenseMatrix>() {
    let scalar = |x: f64| M::T::from_f64(x).unwrap();
    // Asserts that the leading 2x2 block is still [[1, 2], [3, 4]].
    let assert_leading_block = |m: &M| {
        assert_eq!(m.get_index(0, 0), M::T::one());
        assert_eq!(m.get_index(0, 1), scalar(2.0));
        assert_eq!(m.get_index(1, 0), scalar(3.0));
        assert_eq!(m.get_index(1, 1), scalar(4.0));
    };

    let mut a = M::zeros(2, 2, M::C::default());
    a.set_index(0, 0, M::T::one());
    a.set_index(0, 1, scalar(2.0));
    a.set_index(1, 0, scalar(3.0));
    a.set_index(1, 1, scalar(4.0));

    // Grow to three columns.
    a.resize_cols(3);
    assert_eq!(a.ncols(), 3);
    assert_eq!(a.nrows(), 2);
    assert_leading_block(&a);

    // The new column can be written and read back.
    a.set_index(0, 2, scalar(5.0));
    a.set_index(1, 2, scalar(6.0));
    assert_eq!(a.get_index(0, 2), scalar(5.0));
    assert_eq!(a.get_index(1, 2), scalar(6.0));

    // Shrink back to two columns.
    a.resize_cols(2);
    assert_eq!(a.ncols(), 2);
    assert_eq!(a.nrows(), 2);
    assert_leading_block(&a);
}
/// Checks that `M::from_vec(2, 2, [1, 3, 2, 4], ..)` fills the matrix column by column,
/// giving `[[1, 2], [3, 4]]`.
pub fn test_from_vec<M: DenseMatrix>() {
    let data: Vec<M::T> = [1.0, 3.0, 2.0, 4.0].iter().map(|&x| f::<M>(x)).collect();
    let a = M::from_vec(2, 2, data, M::C::default());
    assert_eq!(a.nrows(), 2);
    assert_eq!(a.ncols(), 2);
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 1.0), (1, 0, 3.0), (0, 1, 2.0), (1, 1, 4.0)].iter() {
        assert_eq!(a.get_index(row, col), f::<M>(expected));
    }
}
/// Checks `gemm` with `alpha = 1`, `beta = 0`:
/// `[[1, 2], [3, 4]] * [[2, 0], [1, 3]] = [[4, 6], [10, 12]]`.
pub fn test_gemm<M: DenseMatrix>() {
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let a = M::from_vec(2, 2, scalars(&[1.0, 3.0, 2.0, 4.0]), M::C::default());
    let b = M::from_vec(2, 2, scalars(&[2.0, 1.0, 0.0, 3.0]), M::C::default());
    let mut product = M::zeros(2, 2, M::C::default());
    product.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 4.0), (1, 0, 10.0), (0, 1, 6.0), (1, 1, 12.0)].iter() {
        assert_eq!(product.get_index(row, col), f::<M>(expected));
    }
}
/// Checks `mat_mul`: `[[1, 2], [3, 4]] * [[2, 0], [1, 3]] = [[4, 6], [10, 12]]`.
pub fn test_mat_mul<M: DenseMatrix>() {
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let a = M::from_vec(2, 2, scalars(&[1.0, 3.0, 2.0, 4.0]), M::C::default());
    let b = M::from_vec(2, 2, scalars(&[2.0, 1.0, 0.0, 3.0]), M::C::default());
    let product = a.mat_mul(&b);
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 4.0), (1, 0, 10.0), (0, 1, 6.0), (1, 1, 12.0)].iter() {
        assert_eq!(product.get_index(row, col), f::<M>(expected));
    }
}
/// Checks that `columns(0, 2)` of a 2x3 matrix is a 2x2 view holding the first two
/// columns, and that `into_owned` preserves those values.
pub fn test_columns_view<M: DenseMatrix>() {
    let data: Vec<M::T> = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
        .iter()
        .map(|&x| f::<M>(x))
        .collect();
    let a = M::from_vec(2, 3, data, M::C::default());
    let first_two = a.columns(0, 2);
    assert_eq!(first_two.ncols(), 2);
    assert_eq!(first_two.nrows(), 2);
    let owned = first_two.into_owned();
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 1.0), (1, 0, 4.0), (0, 1, 2.0), (1, 1, 5.0)].iter() {
        assert_eq!(owned.get_index(row, col), f::<M>(expected));
    }
}
/// Checks that `column(1)` of `[[1, 2], [3, 4]]` is a view over `[2, 4]`.
pub fn test_column_view<M: DenseMatrix>() {
    // Brings the view's `get_index` into scope.
    use crate::VectorView;

    let data: Vec<M::T> = [1.0, 3.0, 2.0, 4.0].iter().map(|&x| f::<M>(x)).collect();
    let a = M::from_vec(2, 2, data, M::C::default());
    let second_column = a.column(1);
    for &(position, expected) in [(0, 2.0), (1, 4.0)].iter() {
        assert_eq!(second_column.get_index(position), f::<M>(expected));
    }
}
/// Batched variant of the zeros check: with a two-batch context, `M::zeros(2, 3, ctx)`
/// must be 2x3 and store no non-zero value.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_zeros_m<M: Matrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let zero_matrix = M::zeros(2, 3, ctx);
    assert_eq!(zero_matrix.nrows(), 2);
    assert_eq!(zero_matrix.ncols(), 3);
    let stored = triplet_values(&zero_matrix);
    let all_zero = stored.iter().all(|value| value.is_zero());
    assert!(stored.is_empty() || all_zero);
}
/// Batched `gemv` with `alpha = 1`, `beta = 0`: matrix and vector both carry two
/// batches, and each batch of `y` must hold the product of the matching batches.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemv_m<M: Matrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    // Four values for the first batch followed by four for the second.
    let values = scalars(&[1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0]);
    let a = M::try_from_triplets(2, 2, coords, values, ctx.clone()).unwrap();
    let x = M::V::from_vec(scalars(&[1.0, 2.0, 1.0, 1.0]), ctx.clone());
    let mut y = M::V::zeros(2, ctx);
    a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
    assert_eq!(y.clone_as_vec(), scalars(&[5.0, 11.0, 11.0, 15.0]));
}
/// Batched `gemv` where the matrix has two batches but `x` is built with a default
/// context: the same `x = [1, 2]` must be applied to both batches.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemv_broadcast_x_m<M: Matrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    // Four values for the first batch followed by four for the second.
    let values = scalars(&[1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0]);
    let a = M::try_from_triplets(2, 2, coords, values, ctx.clone()).unwrap();
    let x = M::V::from_vec(scalars(&[1.0, 2.0]), M::C::default());
    let mut y = M::V::zeros(2, ctx);
    a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
    assert_eq!(y.clone_as_vec(), scalars(&[5.0, 11.0, 17.0, 23.0]));
}
/// Batched `gemv` where the matrix is built with a single-batch clone of the context
/// while `x` and `y` carry two batches: the one matrix must be applied to each batch
/// of `x`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemv_broadcast_mat_m<M: Matrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let values = scalars(&[1.0, 3.0, 2.0, 4.0]);
    let single_batch_ctx = ctx.clone_with_nbatch(1).unwrap();
    let a = M::try_from_triplets(2, 2, coords, values, single_batch_ctx).unwrap();
    let x = M::V::from_vec(scalars(&[1.0, 2.0, 3.0, 4.0]), ctx.clone());
    let mut y = M::V::zeros(2, ctx);
    a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
    assert_eq!(y.clone_as_vec(), scalars(&[5.0, 11.0, 11.0, 25.0]));
}
/// Batched `from_diagonal`: a two-batch vector `[2, 3 | 4, 5]` must give a 2x2 matrix
/// whose stored diagonal entries are non-zero and whose stored off-diagonal entries
/// are zero.
///
/// Coordinates and values are paired positionally, so only as many values as there are
/// coordinates are inspected (same coverage as the earlier index lookup).
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_from_diagonal_m<M: Matrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let v = M::V::from_vec(
        vec![f::<M>(2.0), f::<M>(3.0), f::<M>(4.0), f::<M>(5.0)],
        ctx,
    );
    let a = M::from_diagonal(&v);
    assert_eq!(a.nrows(), 2);
    assert_eq!(a.ncols(), 2);
    let idx = triplet_indices(&a);
    let vals = triplet_values(&a);
    // Walk coordinates and values in lockstep instead of searching `idx` for the
    // element that is already being visited.
    for (&(i, j), val) in idx.iter().zip(vals.iter()) {
        if i == j {
            assert!(*val != M::T::zero(), "diagonal entry should be non-zero");
        } else {
            assert!(val.is_zero(), "off-diagonal entry should be zero");
        }
    }
}
/// Batched `copy_from`: after copying a two-batch 2x2 matrix into a zero matrix, the
/// stored values of the target must equal the eight source values.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_copy_from_m<M: Matrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let source_values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let source =
        M::try_from_triplets(2, 2, coords, scalars(&source_values), ctx.clone()).unwrap();
    let mut target = M::zeros(2, 2, ctx);
    target.copy_from(&source);
    assert_eq!(triplet_values(&target), scalars(&source_values));
}
/// Batched `set_column`: writing the two-batch vector `[5, 6 | 7, 8]` into column 0 of
/// an all-zero two-batch matrix must give the values `[5, 6, 0, 0 | 7, 8, 0, 0]`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_set_column_m<M: Matrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let zeros = scalars(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    let mut a = M::try_from_triplets(2, 2, coords, zeros, ctx.clone()).unwrap();
    let column = M::V::from_vec(scalars(&[5.0, 6.0, 7.0, 8.0]), ctx);
    a.set_column(0, &column);
    assert_eq!(
        triplet_values(&a),
        scalars(&[5.0, 6.0, 0.0, 0.0, 7.0, 8.0, 0.0, 0.0])
    );
}
/// Batched `scale_add_and_assign(x, 2, y)`: every stored value of the result must be
/// `x + 2 * y` in both batches.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_scale_add_and_assign_m<M: Matrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let x_values = scalars(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    let y_values = scalars(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]);
    let x = M::try_from_triplets(2, 2, coords.clone(), x_values, ctx.clone()).unwrap();
    let y = M::try_from_triplets(2, 2, coords, y_values, ctx.clone()).unwrap();
    let mut result = M::zeros(2, 2, ctx);
    result.copy_from(&x);
    result.scale_add_and_assign(&x, f::<M>(2.0), &y);
    assert_eq!(
        triplet_values(&result),
        scalars(&[21.0, 42.0, 63.0, 84.0, 105.0, 126.0, 147.0, 168.0])
    );
}
/// Batched `from_vec`: eight values build a two-batch 2x2 matrix, and `get_index`
/// must return the entries `[[1, 2], [3, 4]]` given by the first four values.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_from_vec<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let data: Vec<M::T> = [1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0]
        .iter()
        .map(|&x| f::<M>(x))
        .collect();
    let a = M::from_vec(2, 2, data, ctx);
    assert_eq!(a.nrows(), 2);
    assert_eq!(a.ncols(), 2);
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 1.0), (1, 0, 3.0), (0, 1, 2.0), (1, 1, 4.0)].iter() {
        assert_eq!(a.get_index(row, col), f::<M>(expected));
    }
}
/// Batched `gemm` with `alpha = 1`, `beta = 0`, all operands carrying two batches.
/// The first four values of `a` form the identity, so `get_index` on the result must
/// return the entries given by the first four values of `b`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemm<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let a = M::from_vec(
        2,
        2,
        scalars(&[1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0]),
        ctx.clone(),
    );
    let b = M::from_vec(
        2,
        2,
        scalars(&[3.0, 5.0, 4.0, 6.0, 1.0, 1.0, 1.0, 1.0]),
        ctx.clone(),
    );
    let mut product = M::zeros(2, 2, ctx);
    product.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 3.0), (1, 0, 5.0), (0, 1, 4.0), (1, 1, 6.0)].iter() {
        assert_eq!(product.get_index(row, col), f::<M>(expected));
    }
}
/// Batched `columns`: the view of the first two columns of a two-batch 2x3 matrix must
/// be 2x2 (also after `into_owned`), and `gemv_o` on it with an all-ones vector must
/// give the per-batch row sums of those two columns.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_columns<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let a = M::from_vec(
        2,
        3,
        scalars(&[
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ]),
        ctx.clone(),
    );

    // Shape of the view and of its owned copy.
    let first_view = a.columns(0, 2);
    assert_eq!(first_view.ncols(), 2);
    assert_eq!(first_view.nrows(), 2);
    let owned = first_view.into_owned();
    assert_eq!(owned.nrows(), 2);
    assert_eq!(owned.ncols(), 2);

    // Product of a fresh view with an all-ones vector.
    let second_view = a.columns(0, 2);
    let ones = M::V::from_vec(scalars(&[1.0, 1.0, 1.0, 1.0]), ctx.clone());
    let mut y = M::V::zeros(2, ctx);
    second_view.gemv_o(f::<M>(1.0), &ones, f::<M>(0.0), &mut y);
    assert_eq!(y.clone_as_vec(), scalars(&[4.0, 6.0, 16.0, 18.0]));
}
/// Batched `gemv_o` on a column view: the first two columns of a two-batch 2x3 matrix
/// are multiplied by the two-batch vector `[1, 1 | 2, 2]`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemv_o_on_columns<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let diff = M::from_vec(
        2,
        3,
        scalars(&[
            1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 7.0, 10.0, 8.0, 11.0, 9.0, 12.0,
        ]),
        ctx.clone(),
    );
    let leading_columns = diff.columns(0, 2);
    let x = M::V::from_vec(scalars(&[1.0, 1.0, 2.0, 2.0]), ctx.clone());
    let mut y = M::V::zeros(2, ctx);
    leading_columns.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
    assert_eq!(y.clone_as_vec(), scalars(&[3.0, 9.0, 30.0, 42.0]));
}
/// `gemv_v` with a matrix built from a default context and a two-batch vector view:
/// the single matrix view must be applied to each batch of `x`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemv_v_broadcast_mat<M: DenseMatrix>(ctx3: M::C) {
    assert_eq!(ctx3.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let ctx1 = M::C::default();
    let diff = M::from_vec(2, 3, scalars(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]), ctx1);
    let leading_columns = diff.columns(0, 2);
    let x = M::V::from_vec(scalars(&[1.0, 1.0, 2.0, 2.0]), ctx3.clone());
    let mut y = M::V::zeros(2, ctx3);
    leading_columns.gemv_v(f::<M>(1.0), &x.as_view(), f::<M>(0.0), &mut y);
    assert_eq!(y.clone_as_vec(), scalars(&[3.0, 9.0, 6.0, 18.0]));
}
/// `gemv_o` with a matrix built from a default context and a two-batch owned vector:
/// the single matrix view must be applied to each batch of `x`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemv_o_broadcast_mat<M: DenseMatrix>(ctx3: M::C) {
    assert_eq!(ctx3.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let ctx1 = M::C::default();
    let diff = M::from_vec(2, 3, scalars(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]), ctx1);
    let leading_columns = diff.columns(0, 2);
    let x = M::V::from_vec(scalars(&[1.0, 1.0, 2.0, 2.0]), ctx3.clone());
    let mut y = M::V::zeros(2, ctx3);
    leading_columns.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
    assert_eq!(y.clone_as_vec(), scalars(&[3.0, 9.0, 6.0, 18.0]));
}
/// Batched `gemm_vo` on column views: the first two columns of `diff` are multiplied
/// by `r` into the first two columns of `result`. The first four values of `r` form
/// the identity, so `get_index` must return the leading entries of `diff`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemm_vo_on_columns<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let diff = M::from_vec(
        2,
        3,
        scalars(&[
            1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 7.0, 10.0, 8.0, 11.0, 9.0, 12.0,
        ]),
        ctx.clone(),
    );
    let r = M::from_vec(
        2,
        2,
        scalars(&[1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0]),
        ctx.clone(),
    );
    let mut result = M::zeros(2, 3, ctx);
    {
        // The mutable view must be dropped before `result` is read below.
        let source = diff.columns(0, 2);
        let mut target = result.columns_mut(0, 2);
        target.gemm_vo(f::<M>(1.0), &source, &r, f::<M>(0.0));
    }
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 1.0), (1, 0, 4.0), (0, 1, 2.0), (1, 1, 5.0)].iter() {
        assert_eq!(result.get_index(row, col), f::<M>(expected));
    }
}
/// Batched `gemm` where `a` and the result carry two batches but `b` is built with a
/// default context. The first four values of `a` form the identity, so `get_index`
/// on the result must return the entries of `b`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemm_broadcast_b<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let a = M::from_vec(
        2,
        2,
        scalars(&[1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 3.0]),
        ctx.clone(),
    );
    let b = M::from_vec(2, 2, scalars(&[1.0, 3.0, 2.0, 4.0]), M::C::default());
    let mut product = M::zeros(2, 2, ctx);
    product.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 1.0), (1, 0, 3.0), (0, 1, 2.0), (1, 1, 4.0)].iter() {
        assert_eq!(product.get_index(row, col), f::<M>(expected));
    }
}
/// Batched `gemm` where `a = diag(1, 2)` is built with a default context while `b` and
/// the result carry two batches; `get_index` on the result must return
/// `[[3, 4], [10, 12]]`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemm_broadcast_a<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let a = M::from_vec(2, 2, scalars(&[1.0, 0.0, 0.0, 2.0]), M::C::default());
    let b = M::from_vec(
        2,
        2,
        scalars(&[3.0, 5.0, 4.0, 6.0, 1.0, 1.0, 1.0, 1.0]),
        ctx.clone(),
    );
    let mut product = M::zeros(2, 2, ctx);
    product.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 3.0), (1, 0, 10.0), (0, 1, 4.0), (1, 1, 12.0)].iter() {
        assert_eq!(product.get_index(row, col), f::<M>(expected));
    }
}
/// Batched `gemv_o` on a column view where `x = [1, 1]` is built with a default
/// context: the same `x` must be applied to both batches of the view.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemv_o_broadcast_x<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let diff = M::from_vec(
        2,
        3,
        scalars(&[
            1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 7.0, 10.0, 8.0, 11.0, 9.0, 12.0,
        ]),
        ctx.clone(),
    );
    let leading_columns = diff.columns(0, 2);
    let x = M::V::from_vec(scalars(&[1.0, 1.0]), M::C::default());
    let mut y = M::V::zeros(2, ctx);
    leading_columns.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
    assert_eq!(y.clone_as_vec(), scalars(&[3.0, 9.0, 15.0, 21.0]));
}
/// Batched `gemm_vo` where the right operand `r` (the identity) is built with a
/// default context while `diff` and `result` carry two batches; `get_index` on the
/// result must return the leading entries of `diff`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemm_vo_broadcast_b<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let diff = M::from_vec(
        2,
        3,
        scalars(&[
            1.0, 4.0, 2.0, 5.0, 3.0, 6.0, 7.0, 10.0, 8.0, 11.0, 9.0, 12.0,
        ]),
        ctx.clone(),
    );
    let r = M::from_vec(2, 2, scalars(&[1.0, 0.0, 0.0, 1.0]), M::C::default());
    let mut result = M::zeros(2, 3, ctx);
    {
        // The mutable view must be dropped before `result` is read below.
        let source = diff.columns(0, 2);
        let mut target = result.columns_mut(0, 2);
        target.gemm_vo(f::<M>(1.0), &source, &r, f::<M>(0.0));
    }
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 1.0), (1, 0, 4.0), (0, 1, 2.0), (1, 1, 5.0)].iter() {
        assert_eq!(result.get_index(row, col), f::<M>(expected));
    }
}
/// Batched `gemm_vo` where the left operand `diff` is built with a default context
/// while `b` and `result` carry two batches. The first four values of `b` form the
/// identity, so `get_index` on the result must return the leading entries of `diff`.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemm_vo_broadcast_a<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let diff = M::from_vec(
        2,
        3,
        scalars(&[1.0, 4.0, 2.0, 5.0, 3.0, 6.0]),
        M::C::default(),
    );
    let b = M::from_vec(
        2,
        2,
        scalars(&[1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 3.0]),
        ctx.clone(),
    );
    let mut result = M::zeros(2, 3, ctx);
    {
        // The mutable view must be dropped before `result` is read below.
        let source = diff.columns(0, 2);
        let mut target = result.columns_mut(0, 2);
        target.gemm_vo(f::<M>(1.0), &source, &b, f::<M>(0.0));
    }
    // (row, column, expected value)
    for &(row, col, expected) in [(0, 0, 1.0), (1, 0, 4.0), (0, 1, 2.0), (1, 1, 5.0)].iter() {
        assert_eq!(result.get_index(row, col), f::<M>(expected));
    }
}
/// Calls `gemm` with a left operand that has three batches while the right operand and
/// the result have two. Nothing is asserted about the outcome; presumably callers wrap
/// this helper in a `#[should_panic]` test.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemm_incompatible_a<M: DenseMatrix>(ctx2: M::C, ctx3: M::C) {
    assert_eq!(ctx2.nbatch(), 2);
    assert_eq!(ctx3.nbatch(), 3);
    let three_batch_lhs = M::zeros(2, 2, ctx3);
    let two_batch_rhs = M::zeros(2, 2, ctx2.clone());
    let mut two_batch_out = M::zeros(2, 2, ctx2);
    let (alpha, beta) = (f::<M>(1.0), f::<M>(0.0));
    two_batch_out.gemm(alpha, &three_batch_lhs, &two_batch_rhs, beta);
}
/// Calls `gemv` with an `x` that has three batches while the matrix and `y` have two.
/// Nothing is asserted about the outcome; presumably callers wrap this helper in a
/// `#[should_panic]` test.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemv_incompatible<M: DenseMatrix>(ctx2: M::C, ctx3: M::C) {
    assert_eq!(ctx2.nbatch(), 2);
    assert_eq!(ctx3.nbatch(), 3);
    let two_batch_matrix = M::zeros(2, 2, ctx2.clone());
    let three_batch_x = M::V::zeros(2, ctx3);
    let mut two_batch_y = M::V::zeros(2, ctx2);
    let (alpha, beta) = (f::<M>(1.0), f::<M>(0.0));
    two_batch_matrix.gemv(alpha, &three_batch_x, beta, &mut two_batch_y);
}
/// Calls `gemm` with a right operand that has three batches while the left operand and
/// the result have two. Nothing is asserted about the outcome; presumably callers wrap
/// this helper in a `#[should_panic]` test.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemm_incompatible<M: DenseMatrix>(ctx2: M::C, ctx3: M::C) {
    assert_eq!(ctx2.nbatch(), 2);
    assert_eq!(ctx3.nbatch(), 3);
    let two_batch_lhs = M::zeros(2, 2, ctx2.clone());
    let three_batch_rhs = M::zeros(2, 2, ctx3);
    let mut two_batch_out = M::zeros(2, 2, ctx2);
    let (alpha, beta) = (f::<M>(1.0), f::<M>(0.0));
    two_batch_out.gemm(alpha, &two_batch_lhs, &three_batch_rhs, beta);
}
/// Batched `resize_cols`: a two-batch 2x2 matrix is grown to three columns and then
/// shrunk to one. After each step the retained entries are checked with `get_index`,
/// and a `gemv` that selects column 0 checks both batches.
///
/// `dead_code` is allowed without the `cuda` feature; presumably only the CUDA tests
/// call the batched helpers.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_resize_cols<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let mut a = M::from_vec(
        2,
        2,
        scalars(&[1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0]),
        ctx.clone(),
    );

    // Grow: old entries are kept and the new column reads as zero.
    a.resize_cols(3);
    assert_eq!(a.ncols(), 3);
    assert_eq!(a.nrows(), 2);
    let after_grow = [
        (0, 0, 1.0),
        (1, 0, 3.0),
        (0, 1, 2.0),
        (1, 1, 4.0),
        (0, 2, 0.0),
        (1, 2, 0.0),
    ];
    for &(row, col, expected) in after_grow.iter() {
        assert_eq!(a.get_index(row, col), f::<M>(expected));
    }
    // Three entries of `x` per batch, each selecting column 0.
    let x = M::V::from_vec(scalars(&[1.0, 0.0, 0.0, 1.0, 0.0, 0.0]), ctx.clone());
    let mut y = M::V::zeros(2, ctx.clone());
    a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
    assert_eq!(y.clone_as_vec(), scalars(&[1.0, 3.0, 5.0, 7.0]));

    // Shrink: only column 0 remains.
    a.resize_cols(1);
    assert_eq!(a.ncols(), 1);
    for &(row, col, expected) in [(0, 0, 1.0), (1, 0, 3.0)].iter() {
        assert_eq!(a.get_index(row, col), f::<M>(expected));
    }
    // One entry of `x2` per batch.
    let x2 = M::V::from_vec(scalars(&[1.0, 1.0]), ctx.clone());
    let mut y2 = M::V::zeros(2, ctx);
    a.gemv(f::<M>(1.0), &x2, f::<M>(0.0), &mut y2);
    assert_eq!(y2.clone_as_vec(), scalars(&[1.0, 3.0, 5.0, 7.0]));
}
/// Checks `matrix * Scale(2)`: every stored value must be doubled.
pub fn test_mul_scalar<M: Matrix>() {
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let a = M::try_from_triplets(
        2,
        2,
        coords,
        scalars(&[1.0, 3.0, 2.0, 4.0]),
        M::C::default(),
    )
    .unwrap();
    let doubled = a * Scale(f::<M>(2.0));
    let value_iter = doubled.triplet_iter().1;
    let stored: Vec<_> = value_iter.collect();
    assert_eq!(stored, scalars(&[2.0, 6.0, 4.0, 8.0]));
}
/// Checks `add_column_to_vector(1, v)`: starting from a zero vector, `v` must end up
/// equal to column 1, i.e. `[3, 4]`.
pub fn test_add_column_to_vector<M: Matrix>() {
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };
    let coords = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let mat = M::try_from_triplets(
        2,
        2,
        coords,
        scalars(&[1.0, 2.0, 3.0, 4.0]),
        M::C::default(),
    )
    .unwrap();
    let mut accumulator = M::V::zeros(2, M::C::default());
    mat.add_column_to_vector(1, &mut accumulator);
    assert_eq!(accumulator.clone_as_vec(), scalars(&[3.0, 4.0]));
}
/// Verifies `owned + &owned` for two 2x2 dense matrices by checking two entries of the sum.
pub fn test_add<M: DenseMatrix>() {
    // Builds a 2x2 matrix from four values in the order expected by `from_vec`.
    let build = |xs: [f64; 4]| {
        M::from_vec(
            2,
            2,
            xs.iter().map(|&x| f::<M>(x)).collect::<Vec<_>>(),
            Default::default(),
        )
    };
    let lhs = build([1.0, 3.0, 2.0, 4.0]);
    let rhs = build([5.0, 7.0, 6.0, 8.0]);

    let sum = lhs + &rhs;

    assert_eq!(sum.get_index(0, 0), f::<M>(6.0));
    assert_eq!(sum.get_index(1, 1), f::<M>(12.0));
}
/// Verifies `owned - &owned` for two 2x2 dense matrices by checking two entries of the
/// difference.
pub fn test_sub<M: DenseMatrix>() {
    // Builds a 2x2 matrix from four values in the order expected by `from_vec`.
    let build = |xs: [f64; 4]| {
        M::from_vec(
            2,
            2,
            xs.iter().map(|&x| f::<M>(x)).collect::<Vec<_>>(),
            Default::default(),
        )
    };
    let minuend = build([5.0, 7.0, 6.0, 8.0]);
    let subtrahend = build([1.0, 3.0, 2.0, 4.0]);

    let difference = minuend - &subtrahend;

    assert_eq!(difference.get_index(0, 0), f::<M>(4.0));
    assert_eq!(difference.get_index(1, 1), f::<M>(4.0));
}
/// Verifies in-place `+=` with a borrowed right-hand side on 2x2 dense matrices.
pub fn test_add_assign<M: DenseMatrix>() {
    // Builds a 2x2 matrix from four values in the order expected by `from_vec`.
    let build = |xs: [f64; 4]| {
        M::from_vec(
            2,
            2,
            xs.iter().map(|&x| f::<M>(x)).collect::<Vec<_>>(),
            Default::default(),
        )
    };
    let mut accumulator = build([1.0, 3.0, 2.0, 4.0]);
    let increment = build([5.0, 7.0, 6.0, 8.0]);

    accumulator += &increment;

    assert_eq!(accumulator.get_index(0, 0), f::<M>(6.0));
    assert_eq!(accumulator.get_index(1, 1), f::<M>(12.0));
}
/// Verifies in-place `-=` with a borrowed right-hand side on 2x2 dense matrices.
pub fn test_sub_assign<M: DenseMatrix>() {
    // Builds a 2x2 matrix from four values in the order expected by `from_vec`.
    let build = |xs: [f64; 4]| {
        M::from_vec(
            2,
            2,
            xs.iter().map(|&x| f::<M>(x)).collect::<Vec<_>>(),
            Default::default(),
        )
    };
    let mut accumulator = build([5.0, 7.0, 6.0, 8.0]);
    let decrement = build([1.0, 3.0, 2.0, 4.0]);

    accumulator -= &decrement;

    assert_eq!(accumulator.get_index(0, 0), f::<M>(4.0));
    assert_eq!(accumulator.get_index(1, 1), f::<M>(4.0));
}
/// Gathers the entries at flat indices `[0, 1, 3, 4]` of a 3x3 matrix holding 1..=9 into a
/// 2x2 matrix and checks all four destination entries.
pub fn test_gather<M: DenseMatrix>() {
    let source_values: Vec<M::T> = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
        .iter()
        .map(|&x| f::<M>(x))
        .collect();
    let source = M::from_vec(3, 3, source_values, Default::default());
    let mut dest = M::zeros(2, 2, Default::default());
    let picks = <M::V as Vector>::Index::from_vec(vec![0, 1, 3, 4], Default::default());

    dest.gather(&source, &picks);

    let expected = [((0, 0), 1.0), ((1, 0), 2.0), ((0, 1), 4.0), ((1, 1), 5.0)];
    for &((row, col), value) in &expected {
        assert_eq!(dest.get_index(row, col), f::<M>(value));
    }
}
/// Writes source entries 0 and 1 of a two-element vector to flat destinations 0 and 3 of a
/// zeroed 2x2 matrix and checks that they land on the diagonal.
pub fn test_set_data_with_indices<M: DenseMatrix>() {
    let mut target = M::zeros(2, 2, Default::default());
    let dst_indices = <M::V as Vector>::Index::from_vec(vec![0, 3], Default::default());
    let src_indices = <M::V as Vector>::Index::from_vec(vec![0, 1], Default::default());
    let payload: Vec<M::T> = [5.0, 6.0].iter().map(|&x| f::<M>(x)).collect();
    let source = M::V::from_vec(payload, Default::default());

    target.set_data_with_indices(&dst_indices, &src_indices, &source);

    for &(diag, value) in &[(0, 5.0), (1, 6.0)] {
        assert_eq!(target.get_index(diag, diag), f::<M>(value));
    }
}
/// Scales a mutable view over both columns of a 2x2 matrix by `Scale(2.0)` and checks that
/// the owning matrix observes the doubled values.
pub fn test_mul_assign_scalar<M: DenseMatrix>() {
    let initial: Vec<M::T> = [1.0, 3.0, 2.0, 4.0].iter().map(|&x| f::<M>(x)).collect();
    let mut matrix = M::from_vec(2, 2, initial, Default::default());

    {
        // The mutable view is confined to this scope so `matrix` can be read afterwards.
        let mut all_columns = matrix.columns_mut(0, 2);
        all_columns *= Scale(f::<M>(2.0));
    }

    for &(diag, value) in &[(0, 2.0), (1, 8.0)] {
        assert_eq!(matrix.get_index(diag, diag), f::<M>(value));
    }
}
/// Round-trips a two-batch 4x4 matrix through `split` and `combine`.
///
/// The matrix is split around the index set `[1, 3]` into four blocks, the blocks are
/// recombined with the same index set, and the recombined matrix must report the same
/// `triplet_iter` values as the original.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_combine<M: DenseMatrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    // 32 values: 1..=16 followed by 101..=116, so the two halves are easy to tell apart.
    let data: Vec<M::T> = (0..2_i32)
        .flat_map(|half| (1..=16_i32).map(move |k| f::<M>(f64::from(half * 100 + k))))
        .collect();
    // `ctx` is not used again, so it is moved rather than cloned.
    let m = M::from_vec(4, 4, data, ctx);
    let alg_indices = <M::V as Vector>::Index::from_vec(vec![1, 3], Default::default());

    let [(ul, _), (ur, _), (ll, _), (lr, _)] = m.split(&alg_indices);
    let recombined = M::combine(&ul, &ur, &ll, &lr, &alg_indices);

    let (_orig_idx, orig_vals) = m.triplet_iter();
    let (_recom_idx, recom_vals) = recombined.triplet_iter();
    let orig_vals: Vec<_> = orig_vals.collect();
    let recom_vals: Vec<_> = recom_vals.collect();
    assert_eq!(orig_vals, recom_vals);
}
/// Batched variant of the add-column test: with a two-batch context, adding column 1 of a
/// 2x2 matrix to a zero vector must yield `[3, 4, 7, 8]`.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_add_column_to_vector_m<M: Matrix>(ctx: M::C) {
    // Converts plain `f64` literals into the matrix scalar type.
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };

    assert_eq!(ctx.nbatch(), 2);
    let positions = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let entries = scalars(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    let matrix = M::try_from_triplets(2, 2, positions, entries, ctx.clone()).unwrap();

    let mut accumulator = M::V::zeros(2, ctx);
    matrix.add_column_to_vector(1, &mut accumulator);

    assert_eq!(
        accumulator.clone_as_vec(),
        scalars(&[3.0, 4.0, 7.0, 8.0])
    );
}
/// Batched variant of the indexed-write test: source entries 0 and 1 are written to flat
/// destinations 0 and 3 of an all-zero 2x2 matrix under a two-batch context, and the eight
/// stored values are compared against the expected per-batch diagonals.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_set_data_with_indices_m<M: Matrix>(ctx: M::C) {
    // Converts plain `f64` literals into the matrix scalar type.
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };

    assert_eq!(ctx.nbatch(), 2);
    let positions = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    // Eight explicit zeros: four triplet positions for each of the two batches.
    let mut target =
        M::try_from_triplets(2, 2, positions, scalars(&[0.0; 8]), ctx.clone()).unwrap();
    let dst_indices = <M::V as Vector>::Index::from_vec(vec![0, 3], Default::default());
    let src_indices = <M::V as Vector>::Index::from_vec(vec![0, 1], Default::default());
    let source = M::V::from_vec(scalars(&[5.0, 6.0, 50.0, 60.0]), ctx);

    target.set_data_with_indices(&dst_indices, &src_indices, &source);

    let stored: Vec<_> = target.triplet_iter().1.collect();
    assert_eq!(
        stored,
        scalars(&[5.0, 0.0, 0.0, 6.0, 50.0, 0.0, 0.0, 60.0])
    );
}
/// Batched variant of the gather test: flat indices `[0, 1, 3, 4]` of a 3x3 source are
/// gathered into a 2x2 destination under a two-batch context and all eight stored values
/// are checked.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gather_m<M: Matrix>(ctx: M::C) {
    // Converts plain `f64` literals into the matrix scalar type.
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };

    assert_eq!(ctx.nbatch(), 2);

    // Every (row, col) position of the 3x3 source, with the row index varying fastest.
    let mut source_positions: Vec<(IndexType, IndexType)> = Vec::new();
    for col in 0..3 {
        for row in 0..3 {
            source_positions.push((row, col));
        }
    }
    // First nine values are 1..=9, the next nine are 10, 20, ..., 90.
    let source_values = scalars(&[
        1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0,
        80.0, 90.0,
    ]);
    let source = M::try_from_triplets(3, 3, source_positions, source_values, ctx.clone()).unwrap();

    let dest_positions = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let mut dest = M::try_from_triplets(2, 2, dest_positions, scalars(&[0.0; 8]), ctx).unwrap();
    let gather_indices =
        <M::V as Vector>::Index::from_vec(vec![0, 1, 3, 4], Default::default());

    dest.gather(&source, &gather_indices);

    let gathered: Vec<_> = dest.triplet_iter().1.collect();
    assert_eq!(
        gathered,
        scalars(&[1.0, 2.0, 4.0, 5.0, 10.0, 20.0, 40.0, 50.0])
    );
}
/// Batched variant of the scalar-multiplication test: under a two-batch context,
/// multiplying by `Scale(2.0)` must double all eight stored values.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_mul_scalar_m<M: Matrix>(ctx: M::C) {
    // Converts plain `f64` literals into the matrix scalar type.
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };

    assert_eq!(ctx.nbatch(), 2);
    let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    let values = scalars(&[1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0]);
    // `ctx` is not used again, so it is moved rather than cloned.
    let a = M::try_from_triplets(2, 2, indices, values, ctx).unwrap();

    let result = a * Scale(f::<M>(2.0));

    let (_, vals) = result.triplet_iter();
    let vals: Vec<_> = vals.collect();
    assert_eq!(
        vals,
        scalars(&[2.0, 6.0, 4.0, 8.0, 10.0, 14.0, 12.0, 16.0])
    );
}
/// Builds a 3x3 diagonal matrix under a two-batch context whose values are `(1, 0, 1)`
/// followed by `(2, 0, 2)` and checks that `partition_indices_by_zero_diagonal` reports
/// index 1 as zero and indices 0 and 2 as non-zero.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_partition_indices<M: Matrix>(ctx: M::C) {
    assert_eq!(ctx.nbatch(), 2);
    let (zero, one, two) = (M::T::zero(), f::<M>(1.0), f::<M>(2.0));

    // Each group of three values has a zero in the middle diagonal slot.
    let mut diagonal_values: Vec<M::T> = Vec::with_capacity(6);
    for &edge in &[one, two] {
        diagonal_values.extend_from_slice(&[edge, zero, edge]);
    }
    let matrix =
        M::try_from_triplets(3, 3, vec![(0, 0), (1, 1), (2, 2)], diagonal_values, ctx).unwrap();

    let (zero_idx, nonzero_idx) = matrix.partition_indices_by_zero_diagonal();

    assert_eq!(zero_idx.clone_as_vec(), vec![1]);
    assert_eq!(nonzero_idx.clone_as_vec(), vec![0, 2]);
}
/// Checks `column_axpy(2.0, 0, 1)` on a 2x2 matrix built with a two-batch context.
///
/// The `get_index` assertions show that column 1 becomes `column 1 + 2 * column 0` while
/// column 0 is untouched. The final `triplet_iter` comparison additionally pins the values
/// derived from the second half of the input data, which the `get_index` assertions alone
/// do not cover.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_column_axpy<M: DenseMatrix>(ctx: M::C) {
    // Converts plain `f64` literals into the matrix scalar type.
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };

    assert_eq!(ctx.nbatch(), 2);
    let mut a = M::from_vec(
        2,
        2,
        scalars(&[1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0]),
        ctx,
    );

    a.column_axpy(f::<M>(2.0), 0, 1);

    assert_eq!(a.get_index(0, 0), f::<M>(1.0));
    assert_eq!(a.get_index(0, 1), f::<M>(4.0));
    assert_eq!(a.get_index(1, 0), f::<M>(3.0));
    assert_eq!(a.get_index(1, 1), f::<M>(10.0));

    // All stored values: (1, 3 | 2+2*1, 4+2*3) then (5, 7 | 6+2*5, 8+2*7).
    let (_, vals) = a.triplet_iter();
    let vals: Vec<_> = vals.collect();
    assert_eq!(
        vals,
        scalars(&[1.0, 3.0, 4.0, 10.0, 5.0, 7.0, 16.0, 22.0])
    );
}
/// Checks `mat_mul` for two 2x2 matrices built with a two-batch context.
///
/// The `get_index` assertions verify the product of the first four values of each operand.
/// The final `triplet_iter` comparison also pins the product of the second four values,
/// which was previously supplied as input but never asserted.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_mat_mul<M: DenseMatrix>(ctx: M::C) {
    // Converts plain `f64` literals into the matrix scalar type.
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };

    assert_eq!(ctx.nbatch(), 2);
    let a = M::from_vec(
        2,
        2,
        scalars(&[1.0, 3.0, 2.0, 4.0, 2.0, 1.0, 0.0, 3.0]),
        ctx.clone(),
    );
    // Last use of `ctx`: moved rather than cloned.
    let b = M::from_vec(
        2,
        2,
        scalars(&[2.0, 1.0, 0.0, 3.0, 1.0, 0.0, 2.0, 1.0]),
        ctx,
    );

    let c = a.mat_mul(&b);

    assert_eq!(c.get_index(0, 0), f::<M>(4.0));
    assert_eq!(c.get_index(1, 0), f::<M>(10.0));
    assert_eq!(c.get_index(0, 1), f::<M>(6.0));
    assert_eq!(c.get_index(1, 1), f::<M>(12.0));

    // First product: [[1,2],[3,4]] * [[2,0],[1,3]] = [[4,6],[10,12]].
    // Second product: [[2,0],[1,3]] * [[1,2],[0,1]] = [[2,4],[1,5]].
    let (_, vals) = c.triplet_iter();
    let vals: Vec<_> = vals.collect();
    assert_eq!(
        vals,
        scalars(&[4.0, 10.0, 6.0, 12.0, 2.0, 1.0, 4.0, 5.0])
    );
}
/// Checks `from_diagonal` for a dense matrix built from a four-value vector under a
/// two-batch context.
///
/// The result must be 2x2 with `(2, 3)` on the diagonal and zeros elsewhere according to
/// `get_index`. The final `triplet_iter` comparison also pins the matrix derived from the
/// remaining two input values `(4, 5)`, which was previously never asserted.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_from_diagonal_dense<M: DenseMatrix>(ctx: M::C) {
    // Converts plain `f64` literals into the matrix scalar type.
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };

    assert_eq!(ctx.nbatch(), 2);
    let v = M::V::from_vec(scalars(&[2.0, 3.0, 4.0, 5.0]), ctx);

    let a = M::from_diagonal(&v);

    assert_eq!(a.nrows(), 2);
    assert_eq!(a.ncols(), 2);
    assert_eq!(a.get_index(0, 0), f::<M>(2.0));
    assert_eq!(a.get_index(1, 1), f::<M>(3.0));
    assert_eq!(a.get_index(0, 1), f::<M>(0.0));
    assert_eq!(a.get_index(1, 0), f::<M>(0.0));

    // diag(2, 3) followed by diag(4, 5), each listed as four stored values.
    let (_, vals) = a.triplet_iter();
    let vals: Vec<_> = vals.collect();
    assert_eq!(
        vals,
        scalars(&[2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 5.0])
    );
}
/// Builds the 3x4 fixture matrix shared by the strided-view tests.
///
/// The context is the default one re-created with `nbatch` batches. Values are generated
/// batch by batch, column by column, row by row, as `row + 10 * col + 100 * batch`, so each
/// value encodes where it was generated.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
fn make_strided_matrix<M: DenseMatrix>(nbatch: usize) -> M {
    let ctx = M::C::default().clone_with_nbatch(nbatch).unwrap();
    let nrows = 3;
    let ncols = 4;
    let data: Vec<M::T> = (0..nbatch)
        .flat_map(|batch| {
            (0..ncols).flat_map(move |col| {
                (0..nrows).map(move |row| {
                    f::<M>(row as f64 + col as f64 * 10.0 + batch as f64 * 100.0)
                })
            })
        })
        .collect();
    M::from_vec(nrows, ncols, data, ctx)
}
/// Takes an immutable view of the first two columns of the strided fixture, converts it to
/// an owned matrix and checks its shape and all six `get_index` values.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_strided_matrix_view_into_owned<M: DenseMatrix>(ctx: M::C) {
    let fixture = make_strided_matrix::<M>(ctx.nbatch());
    let owned = fixture.columns(0, 2).into_owned();

    assert_eq!(owned.nrows(), 3);
    assert_eq!(owned.ncols(), 2);
    // Expected values follow the fixture formula `row + 10 * col`.
    for col in 0..2 {
        for row in 0..3 {
            assert_eq!(
                owned.get_index(row, col),
                f::<M>(row as f64 + col as f64 * 10.0)
            );
        }
    }
}
/// Adds an owned 3x2 matrix (built with the default context) to a view of the first two
/// columns of the strided fixture and checks two entries of the sum.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_strided_matrix_view_add_owned<M: DenseMatrix>(ctx: M::C) {
    let fixture = make_strided_matrix::<M>(ctx.nbatch());
    let leading_columns = fixture.columns(0, 2);
    let addend_values: Vec<M::T> = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        .iter()
        .map(|&x| f::<M>(x))
        .collect();
    let addend = M::from_vec(3, 2, addend_values, M::C::default());

    let sum = leading_columns + &addend;

    // 0 + 1 at (0, 0) and 10 + 4 at (0, 1).
    assert_eq!(sum.get_index(0, 0), f::<M>(1.0));
    assert_eq!(sum.get_index(0, 1), f::<M>(14.0));
}
/// Subtracts an owned 3x2 matrix (built with the default context) from a view of the first
/// two columns of the strided fixture. The subtrahend repeats the fixture values
/// `0, 1, 2, 10, 11, 12`, so the two checked entries must be zero.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_strided_matrix_view_sub_owned<M: DenseMatrix>(ctx: M::C) {
    let fixture = make_strided_matrix::<M>(ctx.nbatch());
    let leading_columns = fixture.columns(0, 2);
    let subtrahend_values: Vec<M::T> = [0.0, 1.0, 2.0, 10.0, 11.0, 12.0]
        .iter()
        .map(|&x| f::<M>(x))
        .collect();
    let subtrahend = M::from_vec(3, 2, subtrahend_values, M::C::default());

    let difference = leading_columns - &subtrahend;

    for col in 0..2 {
        assert_eq!(difference.get_index(0, col), f::<M>(0.0));
    }
}
/// Multiplies a view of the first two columns of the strided fixture by `Scale(2.0)` and
/// checks three entries of the product.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_strided_matrix_view_mul_scalar<M: DenseMatrix>(ctx: M::C) {
    let fixture = make_strided_matrix::<M>(ctx.nbatch());
    let leading_columns = fixture.columns(0, 2);

    let doubled = leading_columns * Scale(f::<M>(2.0));

    let expected = [((0, 0), 0.0), ((1, 0), 2.0), ((0, 1), 20.0)];
    for &((row, col), value) in &expected {
        assert_eq!(doubled.get_index(row, col), f::<M>(value));
    }
}
/// Adds an immutable view of columns 2..4 of one strided fixture into a mutable view of
/// columns 0..2 of another and checks three entries of the updated matrix.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_strided_matrix_view_mut_add_assign_view<M: DenseMatrix>(ctx: M::C) {
    let mut target = make_strided_matrix::<M>(ctx.nbatch());
    let source = make_strided_matrix::<M>(ctx.nbatch());
    {
        // Both views are confined to this scope so `target` can be read afterwards.
        let mut leading = target.columns_mut(0, 2);
        let trailing = source.columns(2, 4);
        leading += &trailing;
    }

    // 0 + 20, 1 + 21 and 10 + 30 respectively.
    let expected = [((0, 0), 20.0), ((1, 0), 22.0), ((0, 1), 40.0)];
    for &((row, col), value) in &expected {
        assert_eq!(target.get_index(row, col), f::<M>(value));
    }
}
/// Subtracts an immutable view of columns 0..2 of one strided fixture from a mutable view
/// of the same columns of an identical fixture; the two checked entries must become zero.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_strided_matrix_view_mut_sub_assign_view<M: DenseMatrix>(ctx: M::C) {
    let mut target = make_strided_matrix::<M>(ctx.nbatch());
    let source = make_strided_matrix::<M>(ctx.nbatch());
    {
        // Both views are confined to this scope so `target` can be read afterwards.
        let mut target_view = target.columns_mut(0, 2);
        let source_view = source.columns(0, 2);
        target_view -= &source_view;
    }

    for row in 0..2 {
        assert_eq!(target.get_index(row, 0), f::<M>(0.0));
    }
}
/// Scales a mutable view of the first two columns of the strided fixture in place by
/// `Scale(2.0)` and checks three entries of the owning matrix.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_strided_matrix_view_mut_mul_assign_scalar<M: DenseMatrix>(ctx: M::C) {
    let mut fixture = make_strided_matrix::<M>(ctx.nbatch());
    {
        // The mutable view is confined to this scope so `fixture` can be read afterwards.
        let mut leading = fixture.columns_mut(0, 2);
        leading *= Scale(f::<M>(2.0));
    }

    let expected = [((0, 0), 0.0), ((1, 0), 2.0), ((0, 1), 20.0)];
    for &((row, col), value) in &expected {
        assert_eq!(fixture.get_index(row, col), f::<M>(value));
    }
}
/// Converts a mutable view of the first two columns of a 2x3 matrix into an owned matrix
/// and checks its shape and all four entries.
pub fn test_view_mut_into_owned<M: DenseMatrix>() {
    let initial: Vec<M::T> = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        .iter()
        .map(|&x| f::<M>(x))
        .collect();
    let mut matrix = M::from_vec(2, 3, initial, Default::default());

    let owned = matrix.columns_mut(0, 2).into_owned();

    assert_eq!(owned.nrows(), 2);
    assert_eq!(owned.ncols(), 2);
    let expected = [((0, 0), 1.0), ((1, 0), 2.0), ((0, 1), 3.0), ((1, 1), 4.0)];
    for &((row, col), value) in &expected {
        assert_eq!(owned.get_index(row, col), f::<M>(value));
    }
}
/// Adds one mutable column view into another (`view_mut += &view_mut`) for two 2x2
/// matrices and checks all four entries of the updated left-hand matrix.
pub fn test_view_mut_add_assign_view_mut<M: DenseMatrix>() {
    // Builds a 2x2 matrix from four values in the order expected by `from_vec`.
    let build = |xs: [f64; 4]| {
        M::from_vec(
            2,
            2,
            xs.iter().map(|&x| f::<M>(x)).collect::<Vec<_>>(),
            Default::default(),
        )
    };
    let mut target = build([1.0, 3.0, 2.0, 4.0]);
    let mut source = build([10.0, 30.0, 20.0, 40.0]);
    {
        // Both views are confined to this scope so `target` can be read afterwards.
        let mut target_view = target.columns_mut(0, 2);
        let source_view = source.columns_mut(0, 2);
        target_view += &source_view;
    }

    let expected = [((0, 0), 11.0), ((1, 0), 33.0), ((0, 1), 22.0), ((1, 1), 44.0)];
    for &((row, col), value) in &expected {
        assert_eq!(target.get_index(row, col), f::<M>(value));
    }
}
/// Subtracts one mutable column view from another (`view_mut -= &view_mut`) for two 2x2
/// matrices and checks all four entries of the updated left-hand matrix.
pub fn test_view_mut_sub_assign_view_mut<M: DenseMatrix>() {
    // Builds a 2x2 matrix from four values in the order expected by `from_vec`.
    let build = |xs: [f64; 4]| {
        M::from_vec(
            2,
            2,
            xs.iter().map(|&x| f::<M>(x)).collect::<Vec<_>>(),
            Default::default(),
        )
    };
    let mut target = build([10.0, 30.0, 20.0, 40.0]);
    let mut source = build([1.0, 3.0, 2.0, 4.0]);
    {
        // Both views are confined to this scope so `target` can be read afterwards.
        let mut target_view = target.columns_mut(0, 2);
        let source_view = source.columns_mut(0, 2);
        target_view -= &source_view;
    }

    let expected = [((0, 0), 9.0), ((1, 0), 27.0), ((0, 1), 18.0), ((1, 1), 36.0)];
    for &((row, col), value) in &expected {
        assert_eq!(target.get_index(row, col), f::<M>(value));
    }
}
/// Runs `gemm_oo` on a mutable view of the first two columns of a zeroed 2x3 matrix, with
/// the second operand being the identity, and checks that the viewed columns receive the
/// first operand while the third column stays zero.
pub fn test_gemm_oo_on_columns<M: DenseMatrix>() {
    // Builds a 2x2 matrix from four values in the order expected by `from_vec`.
    let build = |xs: [f64; 4]| {
        M::from_vec(
            2,
            2,
            xs.iter().map(|&x| f::<M>(x)).collect::<Vec<_>>(),
            Default::default(),
        )
    };
    let lhs = build([1.0, 3.0, 2.0, 4.0]);
    let identity = build([1.0, 0.0, 0.0, 1.0]);
    let mut output = M::zeros(2, 3, Default::default());
    {
        // The mutable view is confined to this scope so `output` can be read afterwards.
        let mut leading = output.columns_mut(0, 2);
        leading.gemm_oo(f::<M>(1.0), &lhs, &identity, f::<M>(0.0));
    }

    let expected = [
        ((0, 0), 1.0),
        ((1, 0), 3.0),
        ((0, 1), 2.0),
        ((1, 1), 4.0),
        ((0, 2), 0.0),
        ((1, 2), 0.0),
    ];
    for &((row, col), value) in &expected {
        assert_eq!(output.get_index(row, col), f::<M>(value));
    }
}
/// Calls `try_from_triplets` with four index pairs but only three values and discards the
/// result.
///
/// The generated wrapper for this helper is marked `#[should_panic]`, so the call itself is
/// expected to panic rather than return an error.
pub fn test_try_from_triplets_wrong_length<M: Matrix>() {
    let positions = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
    // Deliberately one value short of the number of positions.
    let too_few: Vec<M::T> = [1.0, 2.0, 3.0].iter().map(|&x| f::<M>(x)).collect();
    let _ = M::try_from_triplets(2, 2, positions, too_few, Default::default());
}
/// Converts a mutable view of the first two columns of the strided fixture into an owned
/// matrix and checks its shape, its `get_index` values and every value reported by
/// `triplet_iter`.
///
/// The expected `triplet_iter` values are derived from `ctx.nbatch()` using the fixture
/// formula `row + 10 * col + 100 * batch`, so the check does not silently assume two
/// batches. For `nbatch == 2` this is the list
/// `0, 1, 2, 10, 11, 12, 100, 101, 102, 110, 111, 112`.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_strided_matrix_view_mut_into_owned<M: DenseMatrix>(ctx: M::C) {
    let nbatch = ctx.nbatch();
    let mut matrix = make_strided_matrix::<M>(nbatch);
    let owned = matrix.columns_mut(0, 2).into_owned();

    assert_eq!(owned.nrows(), 3);
    assert_eq!(owned.ncols(), 2);
    assert_eq!(owned.get_index(0, 0), f::<M>(0.0));
    assert_eq!(owned.get_index(1, 0), f::<M>(1.0));
    assert_eq!(owned.get_index(2, 0), f::<M>(2.0));
    assert_eq!(owned.get_index(0, 1), f::<M>(10.0));
    assert_eq!(owned.get_index(1, 1), f::<M>(11.0));
    assert_eq!(owned.get_index(2, 1), f::<M>(12.0));

    // Same generation order as the fixture, restricted to the two viewed columns.
    let expected: Vec<M::T> = (0..nbatch)
        .flat_map(|batch| {
            (0..2_usize).flat_map(move |col| {
                (0..3_usize).map(move |row| {
                    f::<M>(row as f64 + col as f64 * 10.0 + batch as f64 * 100.0)
                })
            })
        })
        .collect();
    let (_, vals) = owned.triplet_iter();
    let vals: Vec<_> = vals.collect();
    assert_eq!(vals, expected);
}
/// Batched variant of `view_mut += &view_mut`: both 2x2 operands are built with a two-batch
/// context and all eight stored values of the updated matrix are checked.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_view_mut_add_assign_view_mut<M: DenseMatrix>(ctx: M::C) {
    // Converts plain `f64` literals into the matrix scalar type.
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };

    assert_eq!(ctx.nbatch(), 2);
    let mut target = M::from_vec(
        2,
        2,
        scalars(&[1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0]),
        ctx.clone(),
    );
    let mut source = M::from_vec(
        2,
        2,
        scalars(&[10.0, 30.0, 20.0, 40.0, 50.0, 70.0, 60.0, 80.0]),
        ctx,
    );
    {
        // Both views are confined to this scope so `target` can be read afterwards.
        let mut target_view = target.columns_mut(0, 2);
        let source_view = source.columns_mut(0, 2);
        target_view += &source_view;
    }

    let stored: Vec<_> = target.triplet_iter().1.collect();
    assert_eq!(
        stored,
        scalars(&[11.0, 33.0, 22.0, 44.0, 55.0, 77.0, 66.0, 88.0])
    );
}
/// Batched variant of `view_mut -= &view_mut`: both 2x2 operands are built with a two-batch
/// context and all eight stored values of the updated matrix are checked.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_view_mut_sub_assign_view_mut<M: DenseMatrix>(ctx: M::C) {
    // Converts plain `f64` literals into the matrix scalar type.
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };

    assert_eq!(ctx.nbatch(), 2);
    let mut target = M::from_vec(
        2,
        2,
        scalars(&[10.0, 30.0, 20.0, 40.0, 50.0, 70.0, 60.0, 80.0]),
        ctx.clone(),
    );
    let mut source = M::from_vec(
        2,
        2,
        scalars(&[1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0]),
        ctx,
    );
    {
        // Both views are confined to this scope so `target` can be read afterwards.
        let mut target_view = target.columns_mut(0, 2);
        let source_view = source.columns_mut(0, 2);
        target_view -= &source_view;
    }

    let stored: Vec<_> = target.triplet_iter().1.collect();
    assert_eq!(
        stored,
        scalars(&[9.0, 27.0, 18.0, 36.0, 45.0, 63.0, 54.0, 72.0])
    );
}
/// Batched variant of the `gemm_oo`-on-a-column-view test.
///
/// Operands are 2x2 matrices built with a two-batch context; the output is a zeroed 2x3
/// matrix of which only the first two columns are written through a mutable view. The
/// `get_index` assertions cover the first four output values, and the `triplet_iter`
/// comparison covers all twelve stored values, including the untouched zero column.
#[cfg_attr(not(feature = "cuda"), allow(dead_code))]
pub fn test_batched_gemm_oo_on_columns<M: DenseMatrix>(ctx: M::C) {
    // Converts plain `f64` literals into the matrix scalar type.
    let scalars = |xs: &[f64]| -> Vec<M::T> { xs.iter().map(|&x| f::<M>(x)).collect() };

    assert_eq!(ctx.nbatch(), 2);
    let lhs = M::from_vec(
        2,
        2,
        scalars(&[1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0]),
        ctx.clone(),
    );
    // Identity values followed by twice-identity values.
    let rhs = M::from_vec(
        2,
        2,
        scalars(&[1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0]),
        ctx.clone(),
    );
    let mut output = M::zeros(2, 3, ctx);
    {
        // The mutable view is confined to this scope so `output` can be read afterwards.
        let mut leading = output.columns_mut(0, 2);
        leading.gemm_oo(f::<M>(1.0), &lhs, &rhs, f::<M>(0.0));
    }

    let spot_checks = [((0, 0), 1.0), ((1, 0), 3.0), ((0, 1), 2.0), ((1, 1), 4.0)];
    for &((row, col), value) in &spot_checks {
        assert_eq!(output.get_index(row, col), f::<M>(value));
    }

    let stored: Vec<_> = output.triplet_iter().1.collect();
    assert_eq!(
        stored,
        scalars(&[
            1.0, 3.0, 2.0, 4.0, 0.0, 0.0, 10.0, 14.0, 12.0, 16.0, 0.0, 0.0,
        ])
    );
}
}
/// Generates one `#[test]` function per non-batched generic matrix test helper.
///
/// Arguments:
/// * `$suffix` - identifier appended to each generated test name via `paste`.
/// * `$M` - concrete matrix type the generic helpers in `crate::matrix::tests` are
///   instantiated with.
///
/// Each generated function only forwards to the matching helper.
#[cfg(test)]
macro_rules! generate_matrix_tests_nonbatched {
($suffix:ident, $M:ty) => {
paste::paste! {
#[test]
fn [<test_zeros_ $suffix>]() {
$crate::matrix::tests::test_zeros::<$M>();
}
#[test]
fn [<test_from_diagonal_ $suffix>]() {
$crate::matrix::tests::test_from_diagonal::<$M>();
}
#[test]
fn [<test_gemv_ $suffix>]() {
$crate::matrix::tests::test_gemv::<$M>();
}
#[test]
fn [<test_set_column_ $suffix>]() {
$crate::matrix::tests::test_set_column::<$M>();
}
#[test]
fn [<test_copy_from_ $suffix>]() {
$crate::matrix::tests::test_copy_from::<$M>();
}
#[test]
fn [<test_scale_add_and_assign_ $suffix>]() {
$crate::matrix::tests::test_scale_add_and_assign::<$M>();
}
#[test]
fn [<test_partition_indices_ $suffix>]() {
$crate::matrix::tests::test_partition_indices_by_zero_diagonal::<$M>();
}
#[test]
fn [<test_mul_scalar_ $suffix>]() {
$crate::matrix::tests::test_mul_scalar::<$M>();
}
#[test]
fn [<test_add_column_to_vector_ $suffix>]() {
$crate::matrix::tests::test_add_column_to_vector::<$M>();
}
// The helper discards the `try_from_triplets` result, so this test passes only if the
// call itself panics.
#[test]
#[should_panic]
fn [<test_try_from_triplets_wrong_length_ $suffix>]() {
$crate::matrix::tests::test_try_from_triplets_wrong_length::<$M>();
}
}
};
}
/// Generates one `#[test]` function per batched generic matrix test helper.
///
/// Arguments:
/// * `$suffix` - identifier appended to each generated test name via `paste`.
/// * `$M` - concrete matrix type the generic helpers in `crate::matrix::tests` are
///   instantiated with.
/// * `$ctx1` - accepted for signature parity with the dense batched generator; not used by
///   any test generated here.
/// * `$ctx2` - context expression passed to every helper; the helpers visible in this file
///   assert that its `nbatch()` is 2.
#[cfg(test)]
#[cfg_attr(not(feature = "cuda"), allow(unused_macros))]
macro_rules! generate_matrix_tests_batched {
($suffix:ident, $M:ty, $ctx1:expr, $ctx2:expr) => {
paste::paste! {
#[test]
fn [<test_batched_add_column_to_vector_ $suffix>]() {
$crate::matrix::tests::test_batched_add_column_to_vector_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_set_data_with_indices_ $suffix>]() {
$crate::matrix::tests::test_batched_set_data_with_indices_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_gather_ $suffix>]() {
$crate::matrix::tests::test_batched_gather_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_mul_scalar_ $suffix>]() {
$crate::matrix::tests::test_batched_mul_scalar_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_partition_indices_ $suffix>]() {
$crate::matrix::tests::test_batched_partition_indices::<$M>($ctx2);
}
#[test]
fn [<test_batched_zeros_ $suffix>]() {
$crate::matrix::tests::test_batched_zeros_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemv_ $suffix>]() {
$crate::matrix::tests::test_batched_gemv_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemv_broadcast_x_ $suffix>]() {
$crate::matrix::tests::test_batched_gemv_broadcast_x_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemv_broadcast_mat_ $suffix>]() {
$crate::matrix::tests::test_batched_gemv_broadcast_mat_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_from_diagonal_ $suffix>]() {
$crate::matrix::tests::test_batched_from_diagonal_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_copy_from_ $suffix>]() {
$crate::matrix::tests::test_batched_copy_from_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_set_column_ $suffix>]() {
$crate::matrix::tests::test_batched_set_column_m::<$M>($ctx2);
}
#[test]
fn [<test_batched_scale_add_ $suffix>]() {
$crate::matrix::tests::test_batched_scale_add_and_assign_m::<$M>($ctx2);
}
}
};
}
/// Generates one `#[test]` function per non-batched generic dense-matrix test helper.
///
/// Arguments:
/// * `$suffix` - identifier appended to each generated test name via `paste`.
/// * `$M` - concrete dense matrix type the generic helpers in `crate::matrix::tests` are
///   instantiated with.
///
/// Each generated function only forwards to the matching helper.
#[cfg(test)]
macro_rules! generate_dense_matrix_tests_nonbatched {
($suffix:ident, $M:ty) => {
paste::paste! {
#[test]
fn [<test_from_vec_ $suffix>]() {
$crate::matrix::tests::test_from_vec::<$M>();
}
#[test]
fn [<test_from_diagonal_dense_ $suffix>]() {
$crate::matrix::tests::test_from_diagonal_dense::<$M>();
}
#[test]
fn [<test_gemm_ $suffix>]() {
$crate::matrix::tests::test_gemm::<$M>();
}
#[test]
fn [<test_mat_mul_ $suffix>]() {
$crate::matrix::tests::test_mat_mul::<$M>();
}
#[test]
fn [<test_columns_view_ $suffix>]() {
$crate::matrix::tests::test_columns_view::<$M>();
}
#[test]
fn [<test_column_view_ $suffix>]() {
$crate::matrix::tests::test_column_view::<$M>();
}
#[test]
fn [<test_column_axpy_ $suffix>]() {
$crate::matrix::tests::test_column_axpy::<$M>();
}
#[test]
fn [<test_resize_cols_ $suffix>]() {
$crate::matrix::tests::test_resize_cols::<$M>();
}
#[test]
fn [<test_add_ $suffix>]() {
$crate::matrix::tests::test_add::<$M>();
}
#[test]
fn [<test_sub_ $suffix>]() {
$crate::matrix::tests::test_sub::<$M>();
}
#[test]
fn [<test_add_assign_ $suffix>]() {
$crate::matrix::tests::test_add_assign::<$M>();
}
#[test]
fn [<test_sub_assign_ $suffix>]() {
$crate::matrix::tests::test_sub_assign::<$M>();
}
#[test]
fn [<test_gather_ $suffix>]() {
$crate::matrix::tests::test_gather::<$M>();
}
#[test]
fn [<test_set_data_with_indices_ $suffix>]() {
$crate::matrix::tests::test_set_data_with_indices::<$M>();
}
#[test]
fn [<test_mul_assign_scalar_ $suffix>]() {
$crate::matrix::tests::test_mul_assign_scalar::<$M>();
}
#[test]
fn [<test_view_mut_into_owned_ $suffix>]() {
$crate::matrix::tests::test_view_mut_into_owned::<$M>();
}
#[test]
fn [<test_view_mut_add_assign_view_mut_ $suffix>]() {
$crate::matrix::tests::test_view_mut_add_assign_view_mut::<$M>();
}
#[test]
fn [<test_view_mut_sub_assign_view_mut_ $suffix>]() {
$crate::matrix::tests::test_view_mut_sub_assign_view_mut::<$M>();
}
#[test]
fn [<test_gemm_oo_on_columns_ $suffix>]() {
$crate::matrix::tests::test_gemm_oo_on_columns::<$M>();
}
}
};
}
/// Generates one `#[test]` function per batched generic dense-matrix test helper.
///
/// Arguments:
/// * `$suffix` - identifier appended to each generated test name via `paste`.
/// * `$M` - concrete dense matrix type the generic helpers in `crate::matrix::tests` are
///   instantiated with.
/// * `$ctx1` - context expression used only by the "incompatible" tests, where
///   `clone_with_nbatch(3)` is called on it to obtain a second, mismatching context.
/// * `$ctx2` - context expression passed to every helper; the helpers visible in this file
///   either assert that its `nbatch()` is 2 or forward `nbatch()` to the strided fixture.
#[cfg(test)]
#[cfg_attr(not(feature = "cuda"), allow(unused_macros))]
macro_rules! generate_dense_matrix_tests_batched {
($suffix:ident, $M:ty, $ctx1:expr, $ctx2:expr) => {
paste::paste! {
#[test]
fn [<test_batched_column_axpy_ $suffix>]() {
$crate::matrix::tests::test_batched_column_axpy::<$M>($ctx2);
}
#[test]
fn [<test_batched_mat_mul_ $suffix>]() {
$crate::matrix::tests::test_batched_mat_mul::<$M>($ctx2);
}
#[test]
fn [<test_batched_from_diagonal_dense_ $suffix>]() {
$crate::matrix::tests::test_batched_from_diagonal_dense::<$M>($ctx2);
}
#[test]
fn [<test_batched_from_vec_ $suffix>]() {
$crate::matrix::tests::test_batched_from_vec::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemm_ $suffix>]() {
$crate::matrix::tests::test_batched_gemm::<$M>($ctx2);
}
#[test]
fn [<test_batched_columns_ $suffix>]() {
$crate::matrix::tests::test_batched_columns::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemv_o_on_columns_ $suffix>]() {
$crate::matrix::tests::test_batched_gemv_o_on_columns::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemm_vo_on_columns_ $suffix>]() {
$crate::matrix::tests::test_batched_gemm_vo_on_columns::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemm_broadcast_b_ $suffix>]() {
$crate::matrix::tests::test_batched_gemm_broadcast_b::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemv_o_broadcast_x_ $suffix>]() {
$crate::matrix::tests::test_batched_gemv_o_broadcast_x::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemv_v_broadcast_mat_ $suffix>]() {
$crate::matrix::tests::test_batched_gemv_v_broadcast_mat::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemv_o_broadcast_mat_ $suffix>]() {
$crate::matrix::tests::test_batched_gemv_o_broadcast_mat::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemm_vo_broadcast_b_ $suffix>]() {
$crate::matrix::tests::test_batched_gemm_vo_broadcast_b::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemm_broadcast_a_ $suffix>]() {
$crate::matrix::tests::test_batched_gemm_broadcast_a::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemm_vo_broadcast_a_ $suffix>]() {
$crate::matrix::tests::test_batched_gemm_vo_broadcast_a::<$M>($ctx2);
}
#[test]
fn [<test_batched_resize_cols_ $suffix>]() {
$crate::matrix::tests::test_batched_resize_cols::<$M>($ctx2);
}
#[test]
fn [<test_batched_combine_ $suffix>]() {
$crate::matrix::tests::test_batched_combine::<$M>($ctx2);
}
// The next three tests mix `$ctx2` with a three-batch context derived from `$ctx1` and
// must panic with a message containing "incompatible nbatch".
#[test]
#[should_panic(expected = "incompatible nbatch")]
fn [<test_batched_gemv_incompatible_ $suffix>]() {
$crate::matrix::tests::test_batched_gemv_incompatible::<$M>($ctx2, $ctx1.clone_with_nbatch(3).unwrap());
}
#[test]
#[should_panic(expected = "incompatible nbatch")]
fn [<test_batched_gemm_incompatible_ $suffix>]() {
$crate::matrix::tests::test_batched_gemm_incompatible::<$M>($ctx2, $ctx1.clone_with_nbatch(3).unwrap());
}
#[test]
#[should_panic(expected = "incompatible nbatch")]
fn [<test_batched_gemm_incompatible_a_ $suffix>]() {
$crate::matrix::tests::test_batched_gemm_incompatible_a::<$M>($ctx2, $ctx1.clone_with_nbatch(3).unwrap());
}
// Strided-view tests: each helper builds its own fixture from `$ctx2.nbatch()`.
#[test]
fn [<test_strided_matrix_view_into_owned_ $suffix>]() {
$crate::matrix::tests::test_strided_matrix_view_into_owned::<$M>($ctx2);
}
#[test]
fn [<test_strided_matrix_view_add_owned_ $suffix>]() {
$crate::matrix::tests::test_strided_matrix_view_add_owned::<$M>($ctx2);
}
#[test]
fn [<test_strided_matrix_view_sub_owned_ $suffix>]() {
$crate::matrix::tests::test_strided_matrix_view_sub_owned::<$M>($ctx2);
}
#[test]
fn [<test_strided_matrix_view_mul_scalar_ $suffix>]() {
$crate::matrix::tests::test_strided_matrix_view_mul_scalar::<$M>($ctx2);
}
#[test]
fn [<test_strided_matrix_view_mut_add_assign_view_ $suffix>]() {
$crate::matrix::tests::test_strided_matrix_view_mut_add_assign_view::<$M>($ctx2);
}
#[test]
fn [<test_strided_matrix_view_mut_sub_assign_view_ $suffix>]() {
$crate::matrix::tests::test_strided_matrix_view_mut_sub_assign_view::<$M>($ctx2);
}
#[test]
fn [<test_strided_matrix_view_mut_mul_assign_scalar_ $suffix>]() {
$crate::matrix::tests::test_strided_matrix_view_mut_mul_assign_scalar::<$M>($ctx2);
}
#[test]
fn [<test_strided_matrix_view_mut_into_owned_ $suffix>]() {
$crate::matrix::tests::test_strided_matrix_view_mut_into_owned::<$M>($ctx2);
}
#[test]
fn [<test_batched_view_mut_add_assign_view_mut_ $suffix>]() {
$crate::matrix::tests::test_batched_view_mut_add_assign_view_mut::<$M>($ctx2);
}
#[test]
fn [<test_batched_view_mut_sub_assign_view_mut_ $suffix>]() {
$crate::matrix::tests::test_batched_view_mut_sub_assign_view_mut::<$M>($ctx2);
}
#[test]
fn [<test_batched_gemm_oo_on_columns_ $suffix>]() {
$crate::matrix::tests::test_batched_gemm_oo_on_columns::<$M>($ctx2);
}
}
};
}
#[cfg(test)]
#[cfg_attr(not(feature = "cuda"), allow(unused_imports))]
pub(crate) use generate_dense_matrix_tests_batched;
#[cfg(test)]
pub(crate) use generate_dense_matrix_tests_nonbatched;
#[cfg(test)]
#[cfg_attr(not(feature = "cuda"), allow(unused_imports))]
pub(crate) use generate_matrix_tests_batched;
#[cfg(test)]
pub(crate) use generate_matrix_tests_nonbatched;