use crate::linalg_helpers::DiagDMatrix;
use nalgebra::{ClosedMul, ComplexField, DMatrix, DVector, Scalar};
use std::ops::Mul;
/// Weights that can be left-multiplied onto a [`DVector`] or [`DMatrix`].
///
/// NOTE(review): presumably these weight the residuals (and Jacobian) of the
/// least-squares problem solved by the surrounding solver; confirm against callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Weights<ScalarType: Scalar + ComplexField> {
    /// No weighting: multiplying by these weights hands the operand back unchanged.
    Unit,
    /// Weighting by a diagonal matrix `W`, applied to an operand `x` as `W * x`.
    Diagonal(DiagDMatrix<ScalarType>),
}
impl<ScalarType> Weights<ScalarType>
where
ScalarType: Scalar + ComplexField,
{
pub fn diagonal<VectorType>(diagonal: VectorType) -> Self
where
DVector<ScalarType>: From<VectorType>,
{
Self::from(DiagDMatrix::from(diagonal))
}
pub fn is_size_correct_for_data_length(&self, data_len: usize) -> bool {
match self {
Weights::Unit => true,
Weights::Diagonal(diag) => diag.size() == data_len,
}
}
}
impl<ScalarType> Default for Weights<ScalarType>
where
ScalarType: Scalar + ComplexField,
{
fn default() -> Self {
Self::Unit
}
}
impl<ScalarType> From<DiagDMatrix<ScalarType>> for Weights<ScalarType>
where
ScalarType: Scalar + ComplexField,
{
fn from(diag: DiagDMatrix<ScalarType>) -> Self {
Self::Diagonal(diag)
}
}
/// Left-multiplies a matrix by the weights.
///
/// [`Weights::Unit`] moves `rhs` straight back to the caller (no new matrix is
/// built); [`Weights::Diagonal`] returns the product of the wrapped diagonal
/// matrix with `rhs`.
impl<ScalarType> Mul<DMatrix<ScalarType>> for &Weights<ScalarType>
where
    ScalarType: ClosedMul + Scalar + ComplexField,
{
    type Output = DMatrix<ScalarType>;

    fn mul(self, rhs: DMatrix<ScalarType>) -> Self::Output {
        match self {
            Weights::Diagonal(diagonal) => diagonal * &rhs,
            Weights::Unit => rhs,
        }
    }
}
/// Left-multiplies a column vector by the weights.
///
/// [`Weights::Unit`] moves `rhs` straight back to the caller (no new vector is
/// built); [`Weights::Diagonal`] returns the product of the wrapped diagonal
/// matrix with `rhs`.
impl<ScalarType> Mul<DVector<ScalarType>> for &Weights<ScalarType>
where
    ScalarType: ClosedMul + Scalar + ComplexField,
{
    type Output = DVector<ScalarType>;

    fn mul(self, rhs: DVector<ScalarType>) -> Self::Output {
        match self {
            Weights::Diagonal(diagonal) => diagonal * &rhs,
            Weights::Unit => rhs,
        }
    }
}
#[cfg(test)]
mod test {
    use crate::solvers::levmar::weights::Weights;
    use nalgebra::{DMatrix, DVector};

    #[test]
    // Upper-case names (W, A) mirror the matrix notation of the math.
    #[allow(non_snake_case)]
    fn unit_weight_produce_correct_results_when_multiplied_to_matrix_or_vector() {
        let W = Weights::default();
        let v = DVector::from(vec![1., 3., 3., 7.]);
        let A = DMatrix::from_element(4, 4, 2.0);
        assert_eq!(&W * v.clone(), v);
        assert_eq!(&W * A.clone(), A);
    }

    #[test]
    // Upper-case names (D, W, A) mirror the matrix notation of the math.
    #[allow(non_snake_case)]
    fn diagonal_weights_produce_correct_results_when_multiplied_to_matrix_or_vector() {
        let diagonal = DVector::from(vec![3., 78., 6., 5.]);
        // Dense reference matrix with the same diagonal, used as ground truth.
        let D = DMatrix::from_diagonal(&diagonal);
        let W = Weights::diagonal(diagonal);
        let v = DVector::from(vec![1., 3., 3., 7.]);
        let mut A = DMatrix::from_element(4, 2, 0.);
        A.set_column(0, &DVector::from(vec![32., 5., 86., 51.]));
        A.set_column(1, &DVector::from(vec![65., 46., 8., 85.]));
        assert_eq!(&D * &v, &W * v);
        assert_eq!(&D * &A, &W * A);
    }

    #[test]
    fn default_weights_are_unit() {
        assert_eq!(Weights::<f64>::default(), Weights::Unit);
    }

    #[test]
    fn size_check_accepts_any_length_for_unit_and_only_matching_length_for_diagonal() {
        let unit = Weights::<f64>::Unit;
        assert!(unit.is_size_correct_for_data_length(0));
        assert!(unit.is_size_correct_for_data_length(42));

        let diagonal = Weights::<f64>::diagonal(DVector::from(vec![1., 2., 3.]));
        assert!(diagonal.is_size_correct_for_data_length(3));
        assert!(!diagonal.is_size_correct_for_data_length(2));
        assert!(!diagonal.is_size_correct_for_data_length(4));
    }
}