
rnn/normalization/layer_norm.rs

use super::NormError;

/// Applies layer normalization to `x` in place: normalizes the slice to zero
/// mean and unit variance, then scales by `gamma` and shifts by `beta`.
pub fn layer_norm_in_place(
    x: &mut [f32],
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
) -> Result<(), NormError> {
    if x.is_empty() {
        return Err(NormError::Empty);
    }
    if gamma.len() != x.len() || beta.len() != x.len() {
        return Err(NormError::ShapeMismatch);
    }
    if !eps.is_finite() || eps <= 0.0 {
        return Err(NormError::InvalidEps);
    }

    // Mean and (biased) variance over the feature dimension.
    let mean = x.iter().copied().sum::<f32>() / x.len() as f32;
    let mut var = 0.0f32;
    for v in x.iter() {
        let d = *v - mean;
        var += d * d;
    }
    var /= x.len() as f32;
    let inv_std = 1.0 / crate::math::sqrtf(var + eps);

    // Normalize, then apply the learned scale and shift.
    for i in 0..x.len() {
        let n = (x[i] - mean) * inv_std;
        x[i] = n * gamma[i] + beta[i];
    }

    Ok(())
}

/// Out-of-place variant: copies `input` into `out`, then normalizes there.
pub fn layer_norm(
    input: &[f32],
    gamma: &[f32],
    beta: &[f32],
    eps: f32,
    out: &mut [f32],
) -> Result<(), NormError> {
    if out.len() != input.len() {
        return Err(NormError::ShapeMismatch);
    }
    out.copy_from_slice(input);
    layer_norm_in_place(out, gamma, beta, eps)
}
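
// A minimal usage sketch, not part of the original file: it assumes the
// `NormError` variants referenced above and that this test module sits in the
// same source file, so `use super::*` brings both functions into scope. The
// input values and tolerances are illustrative only.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normalizes_to_zero_mean_with_identity_affine() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        let gamma = [1.0f32; 4];
        let beta = [0.0f32; 4];
        let mut out = [0.0f32; 4];

        layer_norm(&input, &gamma, &beta, 1e-5, &mut out).unwrap();

        // With gamma = 1 and beta = 0 the normalized output has ~zero mean.
        let mean = out.iter().sum::<f32>() / out.len() as f32;
        assert!(mean.abs() < 1e-5);
    }

    #[test]
    fn rejects_mismatched_parameter_lengths() {
        let mut x = [1.0f32, 2.0];
        let gamma = [1.0f32; 3]; // deliberately the wrong length
        let beta = [0.0f32; 2];
        assert!(matches!(
            layer_norm_in_place(&mut x, &gamma, &beta, 1e-5),
            Err(NormError::ShapeMismatch)
        ));
    }
}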