#[cfg(any(test, doctest))]
mod test;
use nalgebra::U1;
use nalgebra::{
ComplexField, DVector, DefaultAllocator, Dim, Dyn, Matrix, OMatrix, OVector, RawStorageMut,
Scalar,
};
use std::ops::Mul;
mod weights;
pub use weights::Weights;
/// A square diagonal matrix that stores only the entries of its main diagonal.
///
/// The matrix dimension is the length of the stored diagonal vector; every
/// off-diagonal entry is implicitly zero. `D` is the nalgebra dimension type
/// of that vector (for example `Dyn` for a size only known at runtime).
///
/// Note: the derived `Eq` impl only applies when `ScalarType` (and `D`) are
/// themselves `Eq`, so it is not available for floating-point scalars.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DiagMatrix<ScalarType, D>
where
ScalarType: Scalar + ComplexField,
D: Dim,
DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
/// Entries of the main diagonal, starting from the top-left corner.
diagonal: OVector<ScalarType, D>,
}
impl<ScalarType, D> DiagMatrix<ScalarType, D>
where
    ScalarType: Scalar + ComplexField,
    D: Dim,
    DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
    /// Returns the number of columns.
    ///
    /// A diagonal matrix is square, so this is always equal to [`Self::size`].
    pub fn ncols(&self) -> usize {
        self.size()
    }

    /// Returns the number of rows.
    ///
    /// A diagonal matrix is square, so this is always equal to [`Self::size`].
    pub fn nrows(&self) -> usize {
        self.size()
    }

    /// Returns the dimension of the square matrix, i.e. the number of entries
    /// stored on its main diagonal.
    pub fn size(&self) -> usize {
        self.diagonal.len()
    }

    /// Builds a diagonal matrix from a vector of real-valued diagonal entries.
    ///
    /// Each entry is converted into `ScalarType` with
    /// [`ComplexField::from_real`], so the same constructor works for both
    /// real and complex scalar types.
    ///
    /// The allocator bound needed by `map` is already provided by this impl
    /// block, so no extra `where` clause is required here.
    pub fn from_real_field(diagonal: OVector<ScalarType::RealField, D>) -> Self {
        Self::from(diagonal.map(ScalarType::from_real))
    }
}
impl<ScalarType, D> From<OVector<ScalarType, D>> for DiagMatrix<ScalarType, D>
where
ScalarType: Scalar + ComplexField,
D: Dim,
DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
fn from(diagonal: OVector<ScalarType, D>) -> Self {
Self { diagonal }
}
}
/// Left-multiplies a matrix by this diagonal matrix, computing `D * rhs`.
///
/// Row `i` of `rhs` is scaled by the `i`-th diagonal entry. The product is
/// written into `rhs`'s own storage, which is then returned, so no new
/// matrix is allocated.
///
/// # Panics
///
/// Panics if the dimension of the diagonal matrix differs from the number
/// of rows of `rhs`.
impl<ScalarType, R, C, S> Mul<Matrix<ScalarType, R, C, S>> for &DiagMatrix<ScalarType, R>
where
    ScalarType: Mul<ScalarType, Output = ScalarType> + Scalar + ComplexField,
    C: Dim,
    R: Dim,
    S: RawStorageMut<ScalarType, R, C>,
    DefaultAllocator: nalgebra::allocator::Allocator<R>,
{
    type Output = Matrix<ScalarType, R, C, S>;

    fn mul(self, rhs: Matrix<ScalarType, R, C, S>) -> Self::Output {
        let mut product = rhs;
        assert_eq!(
            self.ncols(),
            product.nrows(),
            "Matrix dimensions incorrect for diagonal matrix multiplication."
        );
        let scale = &self.diagonal;
        for mut column in product.column_iter_mut() {
            // Entry (i, j) of D * rhs is d_i * rhs[(i, j)], so every column
            // is multiplied element-wise by the diagonal vector.
            column.component_mul_assign(scale);
        }
        product
    }
}
/// Flattens a dynamically sized matrix into a column vector, consuming it.
///
/// Entries follow nalgebra's column-major storage order, so the columns of
/// `mat` are stacked one after another. For an `r x c` input the result has
/// length `r * c`, and entry `(i, j)` ends up at index `i + j * r`.
///
/// The storage buffer is reinterpreted by `reshape_generic` rather than
/// copied.
pub(crate) fn to_vector<T: Scalar>(mat: OMatrix<T, Dyn, Dyn>) -> DVector<T> {
    // `Scalar` already implies `Clone + Debug`, so no further bounds are needed.
    let len = mat.nrows() * mat.ncols();
    mat.reshape_generic(Dyn(len), U1)
}