use crate::util::DiagMatrix;
use nalgebra::{ComplexField, DefaultAllocator, Dim, Matrix, OVector, RawStorageMut, Scalar};
use std::ops::Mul;
/// Weighting that can be applied to a matrix or vector by left-multiplication.
///
/// * [`Weights::Unit`] leaves the operand untouched.
/// * [`Weights::Diagonal`] multiplies the operand by a stored [`DiagMatrix`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Weights<ScalarType, D>
where
    ScalarType: Scalar + ComplexField,
    D: Dim,
    DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
    /// Identity weighting: the operand is returned as-is.
    Unit,
    /// Weighting by the wrapped diagonal matrix.
    Diagonal(DiagMatrix<ScalarType, D>),
}
impl<ScalarType, D> Weights<ScalarType, D>
where
    ScalarType: Scalar + ComplexField,
    D: Dim,
    DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
    /// Builds [`Weights::Diagonal`] whose diagonal entries are the elements of `diagonal`.
    pub fn diagonal(diagonal: OVector<ScalarType, D>) -> Self {
        let matrix = DiagMatrix::from(diagonal);
        Self::Diagonal(matrix)
    }

    /// Reports whether these weights can be applied to data with `data_len` entries.
    ///
    /// Unit weights are compatible with any length. Diagonal weights are compatible
    /// only when `DiagMatrix::size()` equals `data_len`. NOTE(review): `size()` is
    /// presumably the number of diagonal entries; confirm against `DiagMatrix`.
    pub fn is_size_correct_for_data_length(&self, data_len: usize) -> bool {
        match self {
            Self::Diagonal(matrix) => matrix.size() == data_len,
            Self::Unit => true,
        }
    }
}
impl<ScalarType, D> Default for Weights<ScalarType, D>
where
ScalarType: Scalar + ComplexField,
D: Dim,
DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
fn default() -> Self {
Self::Unit
}
}
impl<ScalarType, D> From<DiagMatrix<ScalarType, D>> for Weights<ScalarType, D>
where
ScalarType: Scalar + ComplexField,
D: Dim,
DefaultAllocator: nalgebra::allocator::Allocator<D>,
{
fn from(diag: DiagMatrix<ScalarType, D>) -> Self {
Self::Diagonal(diag)
}
}
/// Left-multiplies `rhs` by the weights, i.e. computes `W * rhs`.
///
/// [`Weights::Unit`] hands `rhs` back without any arithmetic, while
/// [`Weights::Diagonal`] delegates to the `&DiagMatrix * Matrix` product.
impl<ScalarType, R, C, S> Mul<Matrix<ScalarType, R, C, S>> for &Weights<ScalarType, R>
where
    ScalarType: Mul<ScalarType, Output = ScalarType> + Scalar + ComplexField,
    C: Dim,
    R: Dim,
    S: RawStorageMut<ScalarType, R, C>,
    DefaultAllocator: nalgebra::allocator::Allocator<R>,
    DefaultAllocator: nalgebra::allocator::Allocator<R, C>,
{
    type Output = Matrix<ScalarType, R, C, S>;

    fn mul(self, rhs: Matrix<ScalarType, R, C, S>) -> Self::Output {
        match self {
            Weights::Diagonal(diag) => diag * rhs,
            Weights::Unit => rhs,
        }
    }
}
#[cfg(any(test, doctest))]
mod test {
    use crate::util::weights::Weights;
    use nalgebra::{DMatrix, DVector};

    /// Unit weights must return the operand unchanged, for both vectors and matrices.
    #[test]
    fn unit_weight_produce_correct_results_when_multiplied_to_matrix_or_vector() {
        let weights = Weights::default();
        let vector = DVector::from_vec(vec![1., 3., 3., 7.]);
        let matrix = DMatrix::from_element(4, 4, 2.0);

        assert_eq!(&weights * vector.clone(), vector);
        assert_eq!(&weights * matrix.clone(), matrix);
    }

    /// Diagonal weights must agree with multiplying by the equivalent dense diagonal matrix.
    #[test]
    fn diagonal_weights_produce_correct_results_when_multiplied_to_matrix_or_vector() {
        let entries = DVector::from_vec(vec![3., 78., 6., 5.]);
        let dense = DMatrix::from_diagonal(&entries);
        let weights = Weights::diagonal(entries);
        let vector = DVector::from_vec(vec![1., 3., 3., 7.]);
        // Column-major data: first column is [32, 5, 86, 51], second is [65, 46, 8, 85].
        let matrix =
            DMatrix::from_column_slice(4, 2, &[32., 5., 86., 51., 65., 46., 8., 85.]);

        assert_eq!(&dense * &vector, &weights * vector);
        assert_eq!(&dense * &matrix, &weights * matrix);
    }
}