use crate::csc::CscMatrix;
use crate::csr::CsrMatrix;
use crate::ops::Op;
use crate::ops::serial::{
spadd_csc_prealloc, spadd_csr_prealloc, spadd_pattern, spmm_csc_dense, spmm_csc_pattern,
spmm_csc_prealloc_unchecked, spmm_csr_dense, spmm_csr_pattern, spmm_csr_prealloc_unchecked,
};
use nalgebra::allocator::Allocator;
use nalgebra::base::storage::RawStorage;
use nalgebra::constraint::{DimEq, ShapeConstraint};
use nalgebra::{
ClosedAddAssign, ClosedDivAssign, ClosedMulAssign, ClosedSubAssign, DefaultAllocator, Dim, Dyn,
Matrix, OMatrix, Scalar, U1,
};
use num_traits::{One, Zero};
use std::ops::{Add, Div, DivAssign, Mul, MulAssign, Neg, Sub};
macro_rules! impl_bin_op {
($trait:ident, $method:ident,
<$($life:lifetime),* $(,)? $($scalar_type:ident $(: $bounds:path)?)?>($a:ident : $a_type:ty, $b:ident : $b_type:ty) -> $ret:ty $body:block)
=>
{
impl<$($life,)* $($scalar_type)?> $trait<$b_type> for $a_type
where
$($scalar_type: $($bounds + )? Scalar + ClosedAddAssign + ClosedSubAssign + ClosedMulAssign + Zero + One + Neg<Output=T>)?
{
type Output = $ret;
fn $method(self, $b: $b_type) -> Self::Output {
let $a = self;
$body
}
}
};
}
macro_rules! impl_sp_plus_minus {
($matrix_type:ident, $spadd_fn:ident, +) => {
impl_sp_plus_minus!(Add, add, $matrix_type, $spadd_fn, +, T::one());
};
($matrix_type:ident, $spadd_fn:ident, -) => {
impl_sp_plus_minus!(Sub, sub, $matrix_type, $spadd_fn, -, -T::one());
};
($trait:ident, $method:ident, $matrix_type:ident, $spadd_fn:ident, $sign:tt, $factor:expr) => {
impl_bin_op!($trait, $method,
<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
let pattern = spadd_pattern(a.pattern(), b.pattern());
let values = vec![T::zero(); pattern.nnz()];
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
.unwrap();
$spadd_fn(T::zero(), &mut result, T::one(), Op::NoOp(&a)).unwrap();
$spadd_fn(T::one(), &mut result, $factor, Op::NoOp(&b)).unwrap();
result
});
impl_bin_op!($trait, $method,
<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
&a $sign b
});
impl_bin_op!($trait, $method,
<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
a $sign &b
});
impl_bin_op!($trait, $method, <T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> {
a $sign &b
});
}
}
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, +);
impl_sp_plus_minus!(CsrMatrix, spadd_csr_prealloc, -);
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, +);
impl_sp_plus_minus!(CscMatrix, spadd_csc_prealloc, -);
macro_rules! impl_mul {
($($args:tt)*) => {
impl_bin_op!(Mul, mul, $($args)*);
}
}
macro_rules! impl_spmm {
($matrix_type:ident, $pattern_fn:expr, $spmm_fn:expr) => {
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> {
let pattern = $pattern_fn(a.pattern(), b.pattern());
let values = vec![T::zero(); pattern.nnz()];
let mut result = $matrix_type::try_from_pattern_and_values(pattern, values)
.unwrap();
$spmm_fn(T::zero(),
&mut result,
T::one(),
Op::NoOp(a),
Op::NoOp(b))
.expect("Internal error: spmm failed (please debug).");
result
});
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { a * &b});
impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a $matrix_type<T>) -> $matrix_type<T> { &a * b});
impl_mul!(<T>(a: $matrix_type<T>, b: $matrix_type<T>) -> $matrix_type<T> { &a * &b});
}
}
impl_spmm!(CsrMatrix, spmm_csr_pattern, spmm_csr_prealloc_unchecked);
impl_spmm!(CscMatrix, spmm_csc_pattern, spmm_csc_prealloc_unchecked);
macro_rules! impl_concrete_scalar_matrix_mul {
($matrix_type:ident, $($scalar_type:ty),*) => {
$(
impl_mul!(<>(a: $scalar_type, b: $matrix_type<$scalar_type>)
-> $matrix_type<$scalar_type> { b * a });
impl_mul!(<'a>(a: $scalar_type, b: &'a $matrix_type<$scalar_type>)
-> $matrix_type<$scalar_type> { b * a });
impl_mul!(<'a>(a: &'a $scalar_type, b: $matrix_type<$scalar_type>)
-> $matrix_type<$scalar_type> { b * (*a) });
impl_mul!(<'a>(a: &'a $scalar_type, b: &'a $matrix_type<$scalar_type>)
-> $matrix_type<$scalar_type> { b * *a });
)*
}
}
macro_rules! impl_scalar_mul {
($matrix_type: ident) => {
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
let values: Vec<_> = a.values()
.iter()
.map(|v_i| v_i.clone() * b.clone())
.collect();
$matrix_type::try_from_pattern_and_values(a.pattern().clone(), values).unwrap()
});
impl_mul!(<'a, T>(a: &'a $matrix_type<T>, b: T) -> $matrix_type<T> {
a * &b
});
impl_mul!(<'a, T>(a: $matrix_type<T>, b: &'a T) -> $matrix_type<T> {
let mut a = a;
for value in a.values_mut() {
*value = b.clone() * value.clone();
}
a
});
impl_mul!(<T>(a: $matrix_type<T>, b: T) -> $matrix_type<T> {
a * &b
});
impl_concrete_scalar_matrix_mul!(
$matrix_type,
i8, i16, i32, i64, isize, f32, f64);
impl<T> MulAssign<T> for $matrix_type<T>
where
T: Scalar + ClosedAddAssign + ClosedMulAssign + Zero + One
{
fn mul_assign(&mut self, scalar: T) {
for val in self.values_mut() {
*val *= scalar.clone();
}
}
}
impl<'a, T> MulAssign<&'a T> for $matrix_type<T>
where
T: Scalar + ClosedAddAssign + ClosedMulAssign + Zero + One
{
fn mul_assign(&mut self, scalar: &'a T) {
for val in self.values_mut() {
*val *= scalar.clone();
}
}
}
}
}
impl_scalar_mul!(CsrMatrix);
impl_scalar_mul!(CscMatrix);
macro_rules! impl_neg {
($matrix_type:ident) => {
impl<T> Neg for $matrix_type<T>
where
T: Scalar + Neg<Output = T>,
{
type Output = $matrix_type<T>;
fn neg(mut self) -> Self::Output {
for v_i in self.values_mut() {
*v_i = -v_i.clone();
}
self
}
}
impl<'a, T> Neg for &'a $matrix_type<T>
where
T: Scalar + Neg<Output = T>,
{
type Output = $matrix_type<T>;
fn neg(self) -> Self::Output {
-self.clone()
}
}
};
}
impl_neg!(CsrMatrix);
impl_neg!(CscMatrix);
macro_rules! impl_div {
($matrix_type:ident) => {
impl_bin_op!(Div, div, <T: ClosedDivAssign>(matrix: $matrix_type<T>, scalar: T) -> $matrix_type<T> {
let mut matrix = matrix;
matrix /= scalar;
matrix
});
impl_bin_op!(Div, div, <'a, T: ClosedDivAssign>(matrix: $matrix_type<T>, scalar: &T) -> $matrix_type<T> {
matrix / scalar.clone()
});
impl_bin_op!(Div, div, <'a, T: ClosedDivAssign>(matrix: &'a $matrix_type<T>, scalar: T) -> $matrix_type<T> {
let new_values = matrix.values()
.iter()
.map(|v_i| v_i.clone() / scalar.clone())
.collect();
$matrix_type::try_from_pattern_and_values(matrix.pattern().clone(), new_values)
.unwrap()
});
impl_bin_op!(Div, div, <'a, T: ClosedDivAssign>(matrix: &'a $matrix_type<T>, scalar: &'a T) -> $matrix_type<T> {
matrix / scalar.clone()
});
impl<T> DivAssign<T> for $matrix_type<T>
where T : Scalar + ClosedAddAssign + ClosedMulAssign + ClosedDivAssign + Zero + One
{
fn div_assign(&mut self, scalar: T) {
self.values_mut().iter_mut().for_each(|v_i| *v_i /= scalar.clone());
}
}
impl<'a, T> DivAssign<&'a T> for $matrix_type<T>
where T : Scalar + ClosedAddAssign + ClosedMulAssign + ClosedDivAssign + Zero + One
{
fn div_assign(&mut self, scalar: &'a T) {
*self /= scalar.clone();
}
}
}
}
impl_div!(CsrMatrix);
impl_div!(CscMatrix);
macro_rules! impl_spmm_cs_dense {
($matrix_type_name:ident, $spmm_fn:ident) => {
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
let (_, ncols) = rhs.shape_generic();
let nrows = Dyn(lhs.nrows());
let mut result = OMatrix::<T, Dyn, C>::zeros_generic(nrows, ncols);
$spmm_fn(T::zero(), &mut result, T::one(), Op::NoOp(lhs), Op::NoOp(rhs));
result
});
impl_spmm_cs_dense!(&'a $matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
lhs * &rhs
});
impl_spmm_cs_dense!($matrix_type_name<T>, &'a Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
&lhs * rhs
});
impl_spmm_cs_dense!($matrix_type_name<T>, Matrix<T, R, C, S>, $spmm_fn, |lhs, rhs| {
&lhs * &rhs
});
};
($sparse_matrix_type:ty, $dense_matrix_type:ty, $spmm_fn:ident,
|$lhs:ident, $rhs:ident| $body:tt) =>
{
impl<'a, T, R, C, S> Mul<$dense_matrix_type> for $sparse_matrix_type
where
T: Scalar + ClosedMulAssign + ClosedAddAssign + ClosedSubAssign + ClosedDivAssign + Neg + Zero + One,
R: Dim,
C: Dim,
S: RawStorage<T, R, C>,
DefaultAllocator: Allocator<Dyn, C>,
ShapeConstraint:
DimEq<U1, <<DefaultAllocator as Allocator<Dyn, C>>::Buffer<T> as RawStorage<T, Dyn, C>>::RStride>
+ DimEq<C, Dyn>
+ DimEq<Dyn, <<DefaultAllocator as Allocator<Dyn, C>>::Buffer<T> as RawStorage<T, Dyn, C>>::CStride>
+ DimEq<U1, S::RStride>
+ DimEq<R, Dyn>
+ DimEq<Dyn, S::CStride>
{
type Output = OMatrix<T, Dyn, C>;
fn mul(self, rhs: $dense_matrix_type) -> Self::Output {
let $lhs = self;
let $rhs = rhs;
$body
}
}
}
}
impl_spmm_cs_dense!(CsrMatrix, spmm_csr_dense);
impl_spmm_cs_dense!(CscMatrix, spmm_csc_dense);