concision_utils/utils/
norm.rs

/*
    Appellation: norm <module>
    Contrib: @FL03
*/
5use ndarray::{Array, ArrayBase, Axis, Data, Dimension, RemoveAxis};
6use num_traits::{Float, FromPrimitive};
7
8pub fn layer_norm<A, S, D>(x: &ArrayBase<S, D>, eps: f64) -> Array<A, D>
9where
10    A: Float + FromPrimitive,
11    D: Dimension,
12    S: Data<Elem = A>,
13{
14    let mean = x.mean().unwrap();
15    let denom = {
16        let eps = A::from(eps).unwrap();
17        let var = x.var(A::zero());
18        (var + eps).sqrt()
19    };
20    x.mapv(|xi| (xi - mean) / denom)
21}
22
23pub fn layer_norm_axis<A, S, D>(x: &ArrayBase<S, D>, axis: Axis, eps: f64) -> Array<A, D>
24where
25    A: Float + FromPrimitive,
26    D: RemoveAxis,
27    S: Data<Elem = A>,
28{
29    let eps = A::from(eps).unwrap();
30    let mean = x.mean_axis(axis).unwrap();
31    let var = x.var_axis(axis, A::zero());
32    let inv_std = var.mapv(|v| (v + eps).recip().sqrt());
33
34    (x - &mean) * &inv_std
35}