concision_traits/
norm.rs

1/*
2    Appellation: norm <module>
3    Contrib: @FL03
4*/
/// A trait for computing the L1 norm (the sum of absolute values) of a tensor or array.
pub trait L1Norm {
    /// The type of the computed norm.
    type Output;
    /// Compute the L1 norm of the tensor or array.
    fn l1_norm(&self) -> Self::Output;
}
/// Computes the L2 (Euclidean) norm of a tensor or array.
///
/// Implementors choose the result type through the associated `Output` type.
pub trait L2Norm {
    /// Scalar type returned by `l2_norm`.
    type Output;

    /// Returns the L2 norm of `self`.
    fn l2_norm(&self) -> Self::Output;
}
17
/// The [`Norm`] trait is a unified interface for computing several norms of a tensor or array.
/// At the moment it provides the L1 and L2 norms.
///
/// Any type implementing both [`L1Norm`] and [`L2Norm`] with the same `Output` type receives
/// [`Norm`] automatically through a blanket implementation. Because the method names are shared
/// with those traits, method-call syntax (`x.l1_norm()`) can be ambiguous when more than one of
/// the traits is in scope; use trait-qualified syntax such as `Norm::l1_norm(&x)` in that case.
pub trait Norm {
    /// The type of the computed norm.
    type Output;
    /// Compute the L1 norm of the tensor or array.
    fn l1_norm(&self) -> Self::Output;
    /// Compute the L2 norm of the tensor or array.
    fn l2_norm(&self) -> Self::Output;
}
27
28/*
29 ************* Implementations *************
30*/
31impl<U, V> Norm for U
32where
33    U: L1Norm<Output = V> + L2Norm<Output = V>,
34{
35    type Output = V;
36
37    fn l1_norm(&self) -> Self::Output {
38        <Self as L1Norm>::l1_norm(self)
39    }
40
41    fn l2_norm(&self) -> Self::Output {
42        <Self as L2Norm>::l2_norm(self)
43    }
44}
45
46macro_rules! impl_norm {
47    ($trait:ident::$method:ident($($param:ident: $type:ty),*) => $self:ident$(.$call:ident())*) => {
48        impl<A, S, D> $trait for ndarray::ArrayBase<S, D, A>
49        where
50            A: 'static + Clone + num_traits::Float,
51            D: ndarray::Dimension,
52            S: ndarray::Data<Elem = A>,
53        {
54            type Output = A;
55
56            fn $method(&self, $($param: $type),*) -> Self::Output {
57                self$(.$call())*
58            }
59        }
60
61        impl<'a, A, S, D> $trait for &'a ndarray::ArrayBase<S, D, A>
62        where
63            A: 'static + Clone + num_traits::Float,
64            D: ndarray::Dimension,
65            S: ndarray::Data<Elem = A>,
66        {
67            type Output = A;
68
69            fn $method(&self, $($param: $type),*) -> Self::Output {
70                self$(.$call())*
71            }
72        }
73
74        impl<'a, A, S, D> $trait for &'a mut ndarray::ArrayBase<S, D, A>
75        where
76            A: 'static + Clone + num_traits::Float,
77            D: ndarray::Dimension,
78            S: ndarray::Data<Elem = A>,
79        {
80            type Output = A;
81
82            fn $method(&self, $($param: $type),*) -> Self::Output {
83                self$(.$call())*
84            }
85        }
86    };
87}
88
89impl_norm! { L2Norm::l2_norm() => self.pow2().sum().sqrt() }
90
91impl_norm! { L1Norm::l1_norm() => self.abs().sum() }