use nalgebra::constraint::{
SameNumberOfColumns, SameNumberOfRows, ShapeConstraint,
};
use nalgebra::storage::Storage;
use nalgebra::{
Dim, EuclideanNorm, LpNorm, Matrix, Norm, SimdComplexField, SimdRealField,
UniformNorm, Vector,
};
use num_traits::Zero;
use serde::{Deserialize, Serialize};
/// Marker type selecting the *squared* Euclidean norm.
///
/// It carries no data; its behaviour comes entirely from its `Norm`
/// implementation, which reports squared magnitudes and squared distances
/// (no square root is taken).
#[derive(Clone, Copy, Debug)]
pub struct EuclideanNormSquared;
/// `Norm` implementation that yields squared Euclidean quantities.
///
/// Both methods return the sum of squared moduli without taking a square
/// root, so the results are squared lengths / squared distances.
impl<X: SimdComplexField> Norm<X> for EuclideanNormSquared {
    /// Returns the squared Euclidean norm of `m`, as computed by
    /// `Matrix::norm_squared`.
    #[inline]
    fn norm<R, C, S>(&self, m: &Matrix<X, R, C, S>) -> X::SimdRealField
    where
        R: Dim,
        C: Dim,
        S: Storage<X, R, C>,
    {
        m.norm_squared()
    }

    /// Returns the squared Euclidean distance between `m1` and `m2`.
    ///
    /// The two matrices are traversed together; for every pair of
    /// corresponding entries the squared modulus of their difference is added
    /// to a running total that starts at zero. The `ShapeConstraint` bound
    /// requires the operands to have compatible row and column dimensions.
    #[inline]
    fn metric_distance<R1, C1, S1, R2, C2, S2>(
        &self,
        m1: &Matrix<X, R1, C1, S1>,
        m2: &Matrix<X, R2, C2, S2>,
    ) -> X::SimdRealField
    where
        R1: Dim,
        C1: Dim,
        S1: Storage<X, R1, C1>,
        R2: Dim,
        C2: Dim,
        S2: Storage<X, R2, C2>,
        ShapeConstraint: SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C2>,
    {
        // Seed of the accumulation: the additive identity of the real field.
        let start = X::SimdRealField::zero();
        m1.zip_fold(m2, start, |total, lhs, rhs| {
            total + (lhs - rhs).simd_modulus_squared()
        })
    }
}
/// Runtime-selectable norm used to measure the cost (distance) between two
/// vectors or matrices.
///
/// The value is serialisable, so the choice of norm can be stored in
/// configuration. Each variant is mapped onto a concrete norm by this type's
/// `Norm` implementation.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum NormCost {
    /// Dispatches to the Lp norm with exponent `1`.
    OneNorm,
    /// Dispatches to the Euclidean norm.
    TwoNorm,
    /// Dispatches to the Lp norm with the given exponent `p`.
    ///
    /// NOTE(review): the exponent is a signed `i8`, so zero and negative
    /// values are representable; nothing here validates them — confirm that
    /// callers only pass `p >= 1`.
    LpNorm(i8),
    /// Dispatches to the uniform norm.
    InfNorm,
}
impl NormCost {
pub fn cost<X, R1, S1, R2, S2>(
&self,
a: &Vector<X, R1, S1>,
b: &Vector<X, R2, S2>,
) -> X
where
X: SimdRealField,
R1: Dim,
S1: Storage<X, R1>,
R2: Dim,
S2: Storage<X, R2>,
ShapeConstraint: SameNumberOfRows<R1, R2>,
{
self.metric_distance(a, b)
}
}
impl Default for NormCost {
fn default() -> Self {
Self::TwoNorm
}
}
impl<X: SimdComplexField> Norm<X> for NormCost {
#[inline]
fn norm<R, C, S>(&self, m: &Matrix<X, R, C, S>) -> X::SimdRealField
where
R: Dim,
C: Dim,
S: Storage<X, R, C>,
{
match self {
Self::OneNorm => LpNorm(1).norm(m),
Self::TwoNorm => EuclideanNorm.norm(m),
Self::LpNorm(i) => LpNorm((*i).into()).norm(m),
Self::InfNorm => UniformNorm.norm(m),
}
}
#[inline]
fn metric_distance<R1, C1, S1, R2, C2, S2>(
&self,
m1: &Matrix<X, R1, C1, S1>,
m2: &Matrix<X, R2, C2, S2>,
) -> X::SimdRealField
where
R1: Dim,
C1: Dim,
S1: Storage<X, R1, C1>,
R2: Dim,
C2: Dim,
S2: Storage<X, R2, C2>,
ShapeConstraint: SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C2>,
{
match self {
Self::OneNorm => LpNorm(1).metric_distance(m1, m2),
Self::TwoNorm => EuclideanNorm.metric_distance(m1, m2),
Self::LpNorm(i) => LpNorm((*i).into()).metric_distance(m1, m2),
Self::InfNorm => UniformNorm.metric_distance(m1, m2),
}
}
}