use num_traits::float::FloatCore;
/// The L1 norm of a value.
///
/// For the primitive scalar impls generated in this module the L1 norm is the
/// absolute value (unsigned integers simply return themselves).
pub trait NormL1 {
    /// Type in which the norm is expressed.
    type Output;

    /// Consumes `self` and returns its L1 norm.
    fn norm_l1(self) -> Self::Output;
}
/// The Euclidean (L2) norm of a value, plus its square.
///
/// `norm_l2_sqr` lets callers compare magnitudes without taking a square
/// root. For the primitive scalar impls in this module `norm_l2` is the
/// absolute value and `norm_l2_sqr` is `self * self`.
pub trait NormL2 {
    /// Type in which the norm is expressed.
    type Output;

    /// Consumes `self` and returns its L2 norm.
    fn norm_l2(self) -> Self::Output;

    /// Consumes `self` and returns the square of its L2 norm.
    fn norm_l2_sqr(self) -> Self::Output;
}
/// The L-infinity (maximum) norm of a value.
///
/// For the primitive scalar impls in this module this is the absolute value
/// (unsigned integers simply return themselves).
pub trait NormLInf {
    /// Type in which the norm is expressed.
    type Output;

    /// Consumes `self` and returns its L-infinity norm.
    fn norm_l_inf(self) -> Self::Output;
}
/// Rescaling of a value to unit length.
///
/// NOTE(review): no impl is visible here; presumably "unit" is measured in
/// the L2 norm — confirm against the vector impls.
pub trait Normalize {
    /// Consumes `self` and returns the normalized value of the same type.
    fn normalize(self) -> Self;
}
/// Inner (dot) product of `self` with another value.
///
/// The right-hand operand type `V` defaults to `Self`. For the primitive
/// scalar impls in this module the dot product is plain multiplication.
pub trait Dot<V = Self> {
    /// Type of the product.
    type Output;

    /// Consumes both operands and returns their dot product.
    fn dot(self, other: V) -> Self::Output;
}
/// Outer product of `self` with another value of type `V`.
///
/// NOTE(review): no impl is visible here; the shape of `Output` (presumably a
/// matrix built from the two vectors) is left to each implementor — confirm.
pub trait Outer<V> {
    /// Type of the product.
    type Output;

    /// Consumes both operands and returns their outer product.
    fn outer(self, other: V) -> Self::Output;
}
/// Complex conjugation.
///
/// For the real floating-point impls in this module `conj` is the identity.
pub trait Conj {
    /// Consumes `self` and returns its conjugate.
    fn conj(self) -> Self;
}
/// Test for a value being negligibly small.
///
/// The floating-point impls in this module return `true` when
/// `|self| <= EPSILON` for the type.
pub trait Epsilon {
    /// Returns `true` if `self` is within the type's epsilon of zero.
    fn is_epsilon(&self) -> bool;
}
/// Conversion of `self` into a `V` by broadcasting.
///
/// NOTE(review): no impl is visible here; presumably this replicates a scalar
/// into every component of `V` — confirm against the implementors.
pub trait Broadcast<V> {
    /// Consumes `self` and returns the broadcast value.
    fn broadcast(self) -> V;
}
/// Implements `Dot` for a primitive numeric type as ordinary multiplication.
///
/// Shared by the unsigned, signed and floating-point generators below.
macro_rules! derive_primitive_base {
    ($t:ident) => {
        impl Dot for $t {
            type Output = $t;

            fn dot(self, rhs: $t) -> $t {
                self * rhs
            }
        }
    };
}
/// Implements the scalar traits for an unsigned primitive integer type.
///
/// Generates `Dot` (via `derive_primitive_base!`), `NormL1`, `NormL2` and
/// `NormLInf`. An unsigned value is already its own magnitude, so the L1, L2
/// and L-infinity norms all return it unchanged. `norm_l2_sqr` squares it with
/// plain `*`, which overflows (panicking in debug builds) once the value
/// exceeds `sqrt(MAX)` for the type.
macro_rules! derive_primitive_unsigned {
    ($t:ident) => {
        derive_primitive_base!($t);

        impl NormL1 for $t {
            type Output = $t;

            fn norm_l1(self) -> $t {
                self
            }
        }

        impl NormL2 for $t {
            type Output = $t;

            fn norm_l2_sqr(self) -> $t {
                self * self
            }

            fn norm_l2(self) -> $t {
                self
            }
        }

        impl NormLInf for $t {
            type Output = $t;

            fn norm_l_inf(self) -> $t {
                self
            }
        }
    };
}
/// Implements the scalar traits for a signed primitive integer type.
///
/// Generates `Dot` (via `derive_primitive_base!`), `NormL1`, `NormL2` and
/// `NormLInf`. All three norms are the inherent `abs`, and `norm_l2_sqr` is
/// `self * self`.
///
/// # Overflow
///
/// `MIN` has no positive counterpart, so its norms overflow: `abs` panics in
/// debug builds and returns `MIN` unchanged in release builds. `norm_l2_sqr`
/// likewise overflows once `|self|` exceeds `sqrt(MAX)`.
macro_rules! derive_primitive_signed {
    ($t:ident) => {
        derive_primitive_base!($t);

        impl NormL1 for $t {
            type Output = $t;

            fn norm_l1(self) -> $t {
                Self::abs(self)
            }
        }

        impl NormL2 for $t {
            type Output = $t;

            fn norm_l2_sqr(self) -> $t {
                self * self
            }

            fn norm_l2(self) -> $t {
                Self::abs(self)
            }
        }

        impl NormLInf for $t {
            type Output = $t;

            fn norm_l_inf(self) -> $t {
                Self::abs(self)
            }
        }
    };
}
/// Implements the scalar traits for a primitive floating-point type.
///
/// Generates `Dot` (via `derive_primitive_base!`), `NormL1`, `NormL2`,
/// `NormLInf`, `Conj` and `Epsilon` for `$T`:
///
/// * the L1, L2 and L-infinity norms of a scalar are all its absolute value;
/// * `norm_l2_sqr` is `self * self`;
/// * `conj` is the identity, since a real number is its own conjugate;
/// * `is_epsilon` is true when `|self| <= EPSILON` for the type (false for NaN).
///
/// Every absolute value is taken through the fully-qualified
/// `<$T as FloatCore>::abs`. Then no method depends on method resolution
/// picking the inherent `abs`, which `core` may not provide in the `no_std`
/// builds that `FloatCore` exists for.
macro_rules! derive_primitive_float {
    ($T:ident) => {
        derive_primitive_base!($T);
        impl NormL1 for $T {
            type Output = Self;
            fn norm_l1(self) -> Self {
                <$T as FloatCore>::abs(self)
            }
        }
        impl NormL2 for $T {
            type Output = Self;
            fn norm_l2(self) -> Self {
                <$T as FloatCore>::abs(self)
            }
            fn norm_l2_sqr(self) -> Self {
                self * self
            }
        }
        impl NormLInf for $T {
            type Output = Self;
            fn norm_l_inf(self) -> Self {
                <$T as FloatCore>::abs(self)
            }
        }
        impl Conj for $T {
            fn conj(self) -> Self {
                self
            }
        }
        impl Epsilon for $T {
            fn is_epsilon(&self) -> bool {
                // Qualified like the norms above, instead of a bare
                // `self.abs()`. A NaN compares false, so NaN is never epsilon.
                <$T as FloatCore>::abs(*self) <= Self::EPSILON
            }
        }
    };
}
// Floating-point scalars: norms via `FloatCore::abs`, plus `Conj` and `Epsilon`.
derive_primitive_float!(f32);
derive_primitive_float!(f64);

// Signed integers: norms via the inherent `abs`.
derive_primitive_signed!(i8);
derive_primitive_signed!(i16);
derive_primitive_signed!(i32);
derive_primitive_signed!(i64);
derive_primitive_signed!(isize);

// Unsigned integers: every norm is the value itself.
derive_primitive_unsigned!(u8);
derive_primitive_unsigned!(u16);
derive_primitive_unsigned!(u32);
derive_primitive_unsigned!(u64);
derive_primitive_unsigned!(usize);