concision_linear/norm/layer/mod.rs

/*
    Appellation: layer <module>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
//! # Layer Normalization
//!
//! This module provides the necessary tools for creating and training layer normalization layers.
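//!
//! An input `x` is normalized to zero mean and unit variance,
//! `y = (x - mean(x)) / sqrt(var(x) + eps)`, with the statistics taken over the
//! whole array or along a chosen axis; `EPSILON` supplies a default `eps`.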
pub(crate) use self::utils::*;
pub use self::{config::*, model::*};

pub(crate) mod config;
pub(crate) mod model;

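/// A small constant added to the variance for numerical stability; it keeps
/// the denominator `sqrt(var + eps)` nonzero even when the variance vanishes.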
pub const EPSILON: f64 = 1e-5;

pub(crate) mod prelude {
    pub use super::config::Config as LayerNormConfig;
    pub use super::model::LayerNorm;
}

pub(crate) mod utils {
    use nd::prelude::*;
    use nd::{Data, RemoveAxis, Zip};
    use num::traits::{Float, FromPrimitive};

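    /// Normalize the whole array to zero mean and unit variance, using the
    /// biased variance (`ddof = 0`) stabilized by `eps`. Panics on an empty
    /// input, where `mean` returns `None`.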
    pub(crate) fn layer_norm<A, S, D>(x: &ArrayBase<S, D>, eps: f64) -> Array<A, D>
    where
        A: Float + FromPrimitive,
        D: Dimension,
        S: Data<Elem = A>,
    {
        let mean = x.mean().unwrap();
        let denom = {
            let eps = A::from(eps).unwrap();
            let var = x.var(A::zero());
            (var + eps).sqrt()
        };
        x.mapv(|xi| (xi - mean) / denom)
    }

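    /// Normalize `x` along `axis`: each lane along `axis` is independently
    /// shifted to zero mean and scaled to unit variance.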
    pub(crate) fn layer_norm_axis<A, S, D>(x: &ArrayBase<S, D>, axis: Axis, eps: f64) -> Array<A, D>
    where
        A: Float + FromPrimitive,
        D: RemoveAxis,
        S: Data<Elem = A>,
    {
        let eps = A::from(eps).unwrap();
        let mean = x.mean_axis(axis).unwrap();
        let var = x.var_axis(axis, A::zero());
        let inv_std = var.mapv(|v| (v + eps).sqrt().recip());
        // The reduced statistics have the shape of `x` with `axis` removed, so
        // they pair one-to-one with the lanes of `x` along `axis`. Normalizing
        // lane by lane applies them correctly for any axis; the previous
        // `(x - &mean) * &inv_std` relied on trailing-axis broadcasting and
        // misaligned (or panicked) whenever `axis` was not the leading one.
        let mut out = x.to_owned();
        Zip::from(out.lanes_mut(axis))
            .and(&mean)
            .and(&inv_std)
            .for_each(|mut lane, &m, &s| lane.mapv_inplace(|xi| (xi - m) * s));
        out
    }
}
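
// A couple of sanity tests for the helpers above: a minimal sketch, assuming
// `nd` is this crate's re-export of `ndarray`, as the imports above suggest.
// After normalization the mean should be ~0 and the biased variance ~1.
#[cfg(test)]
mod tests {
    use super::{layer_norm, layer_norm_axis, EPSILON};
    use nd::prelude::*;

    #[test]
    fn normalizes_whole_array() {
        let x = arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let y = layer_norm(&x, EPSILON);
        // zero mean, unit variance up to the epsilon perturbation
        assert!(y.mean().unwrap().abs() < 1e-6);
        assert!((y.var(0.0) - 1.0).abs() < 1e-4);
    }

    #[test]
    fn normalizes_each_lane_along_axis() {
        let x = arr2(&[[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        // normalize over the feature axis; each row is standardized on its own
        let y = layer_norm_axis(&x, Axis(1), EPSILON);
        for row in y.axis_iter(Axis(0)) {
            assert!(row.mean().unwrap().abs() < 1e-6);
            assert!((row.var(0.0) - 1.0).abs() < 1e-4);
        }
    }
}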