concision_linear/norm/layer/
model.rs1use super::Config;
6use crate::{Biased, LinearParams, ParamMode, Unbiased};
7use concision::Forward;
8use nd::prelude::*;
9use nd::{Data, RemoveAxis};
10use num::traits::{Float, FromPrimitive, One, Zero};
11
12pub struct LayerNorm<A = f64, K = crate::Biased, D = Ix2>
21where
22 D: Dimension,
23{
24 config: Config<D>,
25 params: LinearParams<A, K, D>,
26}
27
/// Generates a shape-based constructor on [`LayerNorm`].
///
/// The generated constructor delegates to a `LinearParams` constructor and wraps the
/// result with [`LayerNorm::from_params`]. It must be invoked inside an `impl` block where
/// `A`, `K` and `D` are in scope.
///
/// Accepted forms:
/// - `impl_norm_builder!(name where <bounds>)`: defines `name`, calls `LinearParams::name`.
/// - `impl_norm_builder!(name.call where <bounds>)`: defines `name`, calls `LinearParams::call`.
///
/// Everything after `where` is appended verbatim to the generated function's `where` clause.
macro_rules! impl_norm_builder {
    // Internal rule: emit the constructor itself.
    (@impl $method:ident.$call:ident where $($rest:tt)*) => {
        /// Constructs a new instance for the given `shape`.
        ///
        /// The parameters come from the `LinearParams` constructor selected by the macro
        /// invocation, and the configuration is derived from those parameters.
        pub fn $method<Sh>(shape: Sh) -> Self
        where
            Sh: ShapeBuilder<Dim = D>,
            $($rest)*
        {
            Self::from_params(LinearParams::<A, K, D>::$call(shape))
        }
    };
    // Internal shorthand: the delegate shares the method's name.
    (@impl $method:ident where $($rest:tt)*) => {
        impl_norm_builder!(@impl $method.$method where $($rest)*);
    };
    // Public form with an explicitly named delegate.
    ($method:ident.$call:ident where $($rest:tt)*) => {
        impl_norm_builder!(@impl $method.$call where $($rest)*);
    };
    // Public shorthand: the delegate shares the method's name.
    ($method:ident where $($rest:tt)*) => {
        impl_norm_builder!(@impl $method.$method where $($rest)*);
    };
}
45
46impl<A, K, D> LayerNorm<A, K, D>
47where
48 D: RemoveAxis,
49 K: ParamMode,
50{
51 pub fn from_config(config: Config<D>) -> Self
52 where
53 A: Default,
54 {
55 let params = LinearParams::<A, K, D>::new(config.dim());
56 Self { config, params }
57 }
58
59 pub fn from_elem<Sh>(shape: Sh, elem: A) -> Self
60 where
61 A: Clone,
62 Sh: ShapeBuilder<Dim = D>,
63 {
64 let dim = shape.into_shape().raw_dim().clone();
65 let config = Config::new().dim(dim.clone()).build();
66 let params = LinearParams::<A, K, D>::from_elem(dim, elem);
67 Self { config, params }
68 }
69
70 pub fn from_params(params: LinearParams<A, K, D>) -> Self {
71 let config = Config::new().dim(params.raw_dim()).build();
72 Self { config, params }
73 }
74
75 impl_norm_builder!(new where A: Default);
76 impl_norm_builder!(ones where A: Clone + One);
77 impl_norm_builder!(zeros where A: Clone + Zero);
78
79 pub const fn config(&self) -> &Config<D> {
80 &self.config
81 }
82
83 pub fn is_biased(&self) -> bool {
84 self.params().is_biased()
85 }
86 pub const fn params(&self) -> &LinearParams<A, K, D> {
88 &self.params
89 }
90 pub fn params_mut(&mut self) -> &mut LinearParams<A, K, D> {
92 &mut self.params
93 }
94
95 pub fn dim(&self) -> D::Pattern {
96 self.config().dim()
97 }
98
99 pub fn eps(&self) -> f64 {
100 self.config().eps()
101 }
102
103 pub fn ndim(&self) -> usize {
104 self.config().ndim()
105 }
106
107 pub fn raw_dim(&self) -> D {
108 self.config().raw_dim()
109 }
110
111 pub fn shape(&self) -> &[usize] {
112 self.config().shape()
113 }
114}
115
116impl<A, D> Default for LayerNorm<A, Biased, D>
117where
118 A: Default,
119 D: RemoveAxis,
120{
121 fn default() -> Self {
122 Self {
123 config: Config::default(),
124 params: Default::default(),
125 }
126 }
127}
128
129impl<A, D> Default for LayerNorm<A, Unbiased, D>
130where
131 A: Default,
132 D: RemoveAxis,
133{
134 fn default() -> Self {
135 Self {
136 config: Config::default(),
137 params: Default::default(),
138 }
139 }
140}
141
142impl<A, S, D> Forward<ArrayBase<S, D>> for LayerNorm<A, Biased, D>
143where
144 A: Float + FromPrimitive,
145 D: RemoveAxis,
146 S: Data<Elem = A>,
147{
148 type Output = Array<A, D>;
149
150 fn forward(&self, x: &ArrayBase<S, D>) -> Self::Output {
151 let norm = if let Some(axis) = self.config().axis() {
152 super::layer_norm_axis(x, *axis, self.eps())
153 } else {
154 super::layer_norm(x, self.eps())
155 };
156 norm * self.params().weights() + self.params().bias()
157 }
158}
159
160impl<A, S, D> Forward<ArrayBase<S, D>> for LayerNorm<A, Unbiased, D>
161where
162 A: Float + FromPrimitive,
163 D: RemoveAxis,
164 S: Data<Elem = A>,
165{
166 type Output = Array<A, D>;
167
168 fn forward(&self, x: &ArrayBase<S, D>) -> Self::Output {
169 let norm = if let Some(axis) = self.config().axis() {
170 super::layer_norm_axis(x, *axis, self.eps())
171 } else {
172 super::layer_norm(x, self.eps())
173 };
174 norm * self.params().weights()
175 }
176}