concision_core/traits/
norm.rs

1/*
2    Appellation: norm <module>
3    Contrib: @FL03
4*/
/// A trait for computing the L1 norm (the sum of absolute values) of a tensor or array.
pub trait L1Norm {
    /// The type of the computed norm.
    type Output;
    /// compute the L1 norm of the tensor or array
    fn l1_norm(&self) -> Self::Output;
}
/// Computes the L2 (Euclidean) norm of a tensor or array: the square root of the sum of its
/// squared elements.
pub trait L2Norm {
    /// Type produced by [`L2Norm::l2_norm`].
    type Output;

    /// Returns the L2 norm of `self`.
    fn l2_norm(&self) -> Self::Output;
}
17
/// The [`Norm`] trait serves as a unified interface for computing norms of a tensor or array.
/// At the moment, the trait provides the L1 and L2 norms.
///
/// Any type that implements both [`L1Norm`] and [`L2Norm`] with the same `Output` type gets
/// this trait through a blanket implementation.
///
/// # Method ambiguity
///
/// Because `l1_norm`/`l2_norm` are declared both here and on [`L1Norm`]/[`L2Norm`], a plain
/// method call such as `x.l1_norm()` is ambiguous (E0034) when [`Norm`] and one of those
/// traits are in scope together. Disambiguate with a path call, e.g. `Norm::l1_norm(&x)` or
/// `L1Norm::l1_norm(&x)`.
pub trait Norm {
    /// The type of the computed norms.
    type Output;
    /// compute the L1 norm of the tensor or array
    fn l1_norm(&self) -> Self::Output;
    /// compute the L2 norm of the tensor or array
    fn l2_norm(&self) -> Self::Output;
}
27
28/*
29 ************* Implementations *************
30*/
31use ndarray::{ArrayBase, Data, Dimension, ScalarOperand};
32use num_traits::Float;
33
34impl<U, V> Norm for U
35where
36    U: L1Norm<Output = V> + L2Norm<Output = V>,
37{
38    type Output = V;
39
40    fn l1_norm(&self) -> Self::Output {
41        <Self as L1Norm>::l1_norm(self)
42    }
43
44    fn l2_norm(&self) -> Self::Output {
45        <Self as L2Norm>::l2_norm(self)
46    }
47}
48
/// Implements a single-method norm trait for `ArrayBase<S, D>`, `&ArrayBase<S, D>` and
/// `&mut ArrayBase<S, D>`, with `Output` set to the element type `A`.
///
/// Invocation syntax:
///
/// ```text
/// impl_norm! { Trait::method(param: Type, ...) => self.call_a().call_b() }
/// ```
///
/// The body is restricted to a chain of zero-argument method calls. The leading identifier
/// is matched but not used: the generated body always starts from `self`.
macro_rules! impl_norm {
    // Internal rule: emit one impl of `$trait` for the receiver type `$target`. The optional
    // lifetime is declared on the impl so that reference receivers can name it.
    (@impl [$($lt:lifetime)?] $target:ty => $trait:ident::$method:ident($($param:ident: $type:ty),*) => $($call:ident)*) => {
        impl<$($lt,)? A, S, D> $trait for $target
        where
            A: Float + ScalarOperand,
            D: Dimension,
            S: Data<Elem = A>,
        {
            type Output = A;

            fn $method(&self, $($param: $type),*) -> Self::Output {
                self$(.$call())*
            }
        }
    };
    // Public rule: expand into the owned, shared-reference and mutable-reference impls.
    ($trait:ident::$method:ident($($param:ident: $type:ty),*) => $self:ident$(.$call:ident())*) => {
        impl_norm!(@impl [] ArrayBase<S, D> => $trait::$method($($param: $type),*) => $($call)*);
        impl_norm!(@impl ['a] &'a ArrayBase<S, D> => $trait::$method($($param: $type),*) => $($call)*);
        impl_norm!(@impl ['a] &'a mut ArrayBase<S, D> => $trait::$method($($param: $type),*) => $($call)*);
    };
}
91
92impl_norm! { L2Norm::l2_norm() => self.pow2().sum().sqrt() }
93
94impl_norm! { L1Norm::l1_norm() => self.abs().sum() }