concision_linear/norm/layer/model.rs

/*
    Appellation: layer <module>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
use super::Config;
use crate::{Biased, LinearParams, ParamMode, Unbiased};
use concision::Forward;
use nd::prelude::*;
use nd::{Data, RemoveAxis};
use num::traits::{Float, FromPrimitive, One, Zero};

// #62
/// Layer normalization directly estimates the normalization statistics from the summed
/// inputs to the neurons within a _hidden_ layer, eliminating the need to introduce any
/// additional dependencies between training cases.
///
/// [LayerNorm] follows the [Layer Normalization](https://arxiv.org/abs/1607.06450) paper.
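///
/// ### Example
///
/// A minimal usage sketch; the shape and values below are illustrative assumptions,
/// and the block is marked `ignore` since it depends on this crate's exports:
///
/// ```ignore
/// use concision::Forward;
/// use ndarray::prelude::*;
///
/// // a biased layer norm over a (2, 4) grid with γ = β = 1
/// let layer = LayerNorm::<f64, Biased, Ix2>::ones((2, 4));
/// // a constant input normalizes to zero, so the output is β everywhere
/// let x = Array::<f64, Ix2>::from_elem((2, 4), 2.0);
/// let y = layer.forward(&x);
/// assert_eq!(y.dim(), (2, 4));
/// ```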
pub struct LayerNorm<A = f64, K = crate::Biased, D = Ix2>
where
    D: Dimension,
{
    config: Config<D>,
    params: LinearParams<A, K, D>,
}

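// Generates shape-based constructors that defer to the matching `LinearParams`
// initializer; e.g. `impl_norm_builder!(ones where A: Clone + One)` expands to
// `pub fn ones<Sh: ShapeBuilder<Dim = D>>(shape: Sh) -> Self`, which forwards to
// `LinearParams::ones(shape)` via `Self::from_params`.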
macro_rules! impl_norm_builder {
    ($method:ident$(.$call:ident)? where $($rest:tt)*) => {
        impl_norm_builder!(@impl $method$(.$call)? where $($rest)*);
    };
    (@impl $method:ident where $($rest:tt)*) => {
        impl_norm_builder!(@impl $method.$method where $($rest)*);
    };
    (@impl $method:ident.$call:ident where $($rest:tt)*) => {
        pub fn $method<Sh>(shape: Sh) -> Self
        where
            Sh: ShapeBuilder<Dim = D>,
            $($rest)*
        {
            Self::from_params(LinearParams::<A, K, D>::$call(shape))
        }
    };
}

impl<A, K, D> LayerNorm<A, K, D>
where
    D: RemoveAxis,
    K: ParamMode,
{
    /// Creates a new [LayerNorm] from the given configuration, using the default
    /// value for each parameter.
    pub fn from_config(config: Config<D>) -> Self
    where
        A: Default,
    {
        let params = LinearParams::<A, K, D>::new(config.dim());
        Self { config, params }
    }
    /// Creates a new [LayerNorm] of the given shape whose parameters are filled
    /// with the given element.
    pub fn from_elem<Sh>(shape: Sh, elem: A) -> Self
    where
        A: Clone,
        Sh: ShapeBuilder<Dim = D>,
    {
        let dim = shape.into_shape().raw_dim().clone();
        let config = Config::new().dim(dim.clone()).build();
        let params = LinearParams::<A, K, D>::from_elem(dim, elem);
        Self { config, params }
    }
    /// Creates a new [LayerNorm] from the given parameters, deriving the
    /// configuration from their shape.
    pub fn from_params(params: LinearParams<A, K, D>) -> Self {
        let config = Config::new().dim(params.raw_dim()).build();
        Self { config, params }
    }

    impl_norm_builder!(new where A: Default);
    impl_norm_builder!(ones where A: Clone + One);
    impl_norm_builder!(zeros where A: Clone + Zero);

    /// Returns an immutable reference to the layer's configuration.
    pub const fn config(&self) -> &Config<D> {
        &self.config
    }
    /// Returns true if the layer learns an additive bias (shift) term.
    pub fn is_biased(&self) -> bool {
        self.params().is_biased()
    }
    /// Returns an immutable reference to the layer's parameters.
    pub const fn params(&self) -> &LinearParams<A, K, D> {
        &self.params
    }
    /// Returns a mutable reference to the layer's parameters.
    pub fn params_mut(&mut self) -> &mut LinearParams<A, K, D> {
        &mut self.params
    }
    /// Returns the dimension of the layer as a pattern.
    pub fn dim(&self) -> D::Pattern {
        self.config().dim()
    }
    /// Returns the epsilon used to stabilize the normalization.
    pub fn eps(&self) -> f64 {
        self.config().eps()
    }
    /// Returns the number of dimensions of the layer.
    pub fn ndim(&self) -> usize {
        self.config().ndim()
    }
    /// Returns the raw dimension of the layer.
    pub fn raw_dim(&self) -> D {
        self.config().raw_dim()
    }
    /// Returns the shape of the layer as a slice.
    pub fn shape(&self) -> &[usize] {
        self.config().shape()
    }
}

impl<A, K, D> Default for LayerNorm<A, K, D>
where
    A: Default,
    D: RemoveAxis,
    LinearParams<A, K, D>: Default,
{
    fn default() -> Self {
        Self {
            config: Config::default(),
            params: Default::default(),
        }
    }
}

impl<A, S, D> Forward<ArrayBase<S, D>> for LayerNorm<A, Biased, D>
where
    A: Float + FromPrimitive,
    D: RemoveAxis,
    S: Data<Elem = A>,
{
    type Output = Array<A, D>;

    fn forward(&self, x: &ArrayBase<S, D>) -> Self::Output {
        // normalize along the configured axis, or over the entire input if no axis is set
        let norm = if let Some(axis) = self.config().axis() {
            super::layer_norm_axis(x, *axis, self.eps())
        } else {
            super::layer_norm(x, self.eps())
        };
        // apply the learned affine transform: y = norm ⊙ γ + β
        norm * self.params().weights() + self.params().bias()
    }
}

impl<A, S, D> Forward<ArrayBase<S, D>> for LayerNorm<A, Unbiased, D>
where
    A: Float + FromPrimitive,
    D: RemoveAxis,
    S: Data<Elem = A>,
{
    type Output = Array<A, D>;

    fn forward(&self, x: &ArrayBase<S, D>) -> Self::Output {
        // normalize along the configured axis, or over the entire input if no axis is set
        let norm = if let Some(axis) = self.config().axis() {
            super::layer_norm_axis(x, *axis, self.eps())
        } else {
            super::layer_norm(x, self.eps())
        };
        // scale only: y = norm ⊙ γ, since `Unbiased` layers carry no shift term
        norm * self.params().weights()
    }
}
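
#[cfg(test)]
mod tests {
    //! A minimal smoke-test sketch: the shape and values below are illustrative
    //! assumptions, and only the output shape is asserted since the exact values
    //! depend on the parent module's `layer_norm` implementations.
    use super::*;

    #[test]
    fn forward_preserves_shape() {
        // γ = β = 1 for a (2, 4) biased layer norm
        let layer = LayerNorm::<f64, Biased, Ix2>::ones((2, 4));
        let x = Array::<f64, Ix2>::from_elem((2, 4), 2.0);
        let y = layer.forward(&x);
        assert_eq!(y.dim(), (2, 4));
    }
}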