// rnn/normalization/layer_norm.rs
use super::NormError;

3pub fn layer_norm_in_place(
4 x: &mut [f32],
5 gamma: &[f32],
6 beta: &[f32],
7 eps: f32,
8) -> Result<(), NormError> {
9 if x.is_empty() {
10 return Err(NormError::Empty);
11 }
12 if gamma.len() != x.len() || beta.len() != x.len() {
13 return Err(NormError::ShapeMismatch);
14 }
15 if !eps.is_finite() || eps <= 0.0 {
16 return Err(NormError::InvalidEps);
17 }
18
19 let mean = x.iter().copied().sum::<f32>() / x.len() as f32;
20 let mut var = 0.0f32;
21 for v in x.iter() {
22 let d = *v - mean;
23 var += d * d;
24 }
25 var /= x.len() as f32;
26 let inv_std = 1.0 / crate::math::sqrtf(var + eps);
27
28 for i in 0..x.len() {
29 let n = (x[i] - mean) * inv_std;
30 x[i] = n * gamma[i] + beta[i];
31 }
32
33 Ok(())
34}
35
36pub fn layer_norm(
37 input: &[f32],
38 gamma: &[f32],
39 beta: &[f32],
40 eps: f32,
41 out: &mut [f32],
42) -> Result<(), NormError> {
43 if out.len() != input.len() {
44 return Err(NormError::ShapeMismatch);
45 }
46 out.copy_from_slice(input);
47 layer_norm_in_place(out, gamma, beta, eps)
48}