candle_nn/layer_norm.rs

//! Layer Normalization.
//!
//! This layer applies Layer Normalization over a mini-batch of inputs as described in [`Layer
//! Normalization`]. The input is expected to have three dimensions: a batch dimension, a length,
//! and a hidden size; the normalization is applied over the last dimension.
//!
//! # Example
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu, test_utils::to_vec3_round};
//! use candle_nn::{LayerNorm, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
//! let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?;
//! let layer = LayerNorm::new(w, b, 1e-5);
//!
//! let xs = Tensor::new(
//!     &[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]],
//!     &Cpu)?;
//! let ys = layer.forward(&xs)?;
//! assert_eq!(
//!     to_vec3_round(&ys, 4)?,
//!     &[[[-1.2247, 0.0,  1.2247],
//!        [-1.2247, 0.0,  1.2247],
//!        [ 1.2247, 0.0, -1.2247]]]);
//! # Ok(()) }
//! ```
//!
//! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450
use candle::{DType, Module, Result, Tensor, D};

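/// Configuration for a [`LayerNorm`] layer (setting `remove_mean: false` gives RMS-norm
/// behaviour).
///
/// A minimal sketch of building a config, assuming `LayerNormConfig` is re-exported at the
/// crate root like the other items used in the module-level example:
///
/// ```rust
/// use candle_nn::LayerNormConfig;
///
/// // All defaults: eps = 1e-5, remove_mean = true, affine = true.
/// let cfg = LayerNormConfig::default();
/// assert!(cfg.remove_mean && cfg.affine);
///
/// // `From<f64>` only overrides the epsilon.
/// let cfg = LayerNormConfig::from(1e-6);
/// assert_eq!(cfg.eps, 1e-6);
/// ```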
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
    pub eps: f64,
    /// Whether to subtract the mean before normalizing. This defaults to true; when set to
    /// false, the layer behaves like an RmsNorm.
    pub remove_mean: bool,
    pub affine: bool,
}

impl Default for LayerNormConfig {
    fn default() -> Self {
        Self {
            eps: 1e-5,
            remove_mean: true,
            affine: true,
        }
    }
}

impl From<f64> for LayerNormConfig {
    fn from(eps: f64) -> Self {
        Self {
            eps,
            remove_mean: true,
            affine: true,
        }
    }
}

// This layer norm implementation handles an optional bias as well as the mean removal; with
// `remove_mean` set to false it acts as an RMS norm.
#[derive(Clone, Debug)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Option<Tensor>,
    remove_mean: bool,
    eps: f64,
}

impl LayerNorm {
    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
        Self {
            weight,
            bias: Some(bias),
            remove_mean: true,
            eps,
        }
    }

    pub fn new_no_bias(weight: Tensor, eps: f64) -> Self {
        Self {
            weight,
            bias: None,
            remove_mean: true,
            eps,
        }
    }

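    /// Builds a layer that only rescales the input (no mean subtraction, no bias); this is the
    /// variant that [`RmsNorm`] wraps.
    ///
    /// A minimal usage sketch, mirroring the module-level example but without a bias:
    ///
    /// ```rust
    /// use candle::{Tensor, Device::Cpu};
    /// use candle_nn::{LayerNorm, Module};
    /// # fn main() -> candle::Result<()> {
    /// let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
    /// let layer = LayerNorm::rms_norm(w, 1e-5);
    /// let ys = layer.forward(&Tensor::new(&[[[1f32, 2., 3.]]], &Cpu)?)?;
    /// assert_eq!(ys.dims(), &[1, 1, 3]);
    /// # Ok(()) }
    /// ```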
    pub fn rms_norm(weight: Tensor, eps: f64) -> Self {
        Self {
            weight,
            bias: None,
            remove_mean: false,
            eps,
        }
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

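// For an input `x` normalized over the last (hidden) dimension, the forward pass computes
//   y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps) * weight (+ bias)
// or, when `remove_mean` is false, the RMS-norm variant that skips the mean subtraction.
// Contiguous inputs with both `remove_mean` and a bias set go through the fused
// `crate::ops::layer_norm` kernel instead of the step-by-step implementation below.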
impl Module for LayerNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        if x.is_contiguous() && self.remove_mean {
            if let Some(bias) = self.bias.as_ref() {
                return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
            }
        }
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let hidden_size = x.dim(D::Minus1)?;
        let x = x.to_dtype(internal_dtype)?;
        let x = if self.remove_mean {
            let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}

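/// Creates a [`LayerNorm`] (or an RMS norm, depending on the config) of the given hidden size,
/// fetching its `weight` and optional `bias` parameters through the provided `VarBuilder`.
///
/// A minimal sketch, assuming the crate-root `layer_norm`/`VarBuilder` re-exports; the all-zeros
/// `VarBuilder` is purely for illustration (real code would typically load one from a checkpoint):
///
/// ```rust
/// use candle::{DType, Device, Tensor};
/// use candle_nn::{layer_norm, LayerNormConfig, Module, VarBuilder};
/// # fn main() -> candle::Result<()> {
/// let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
/// let ln = layer_norm(3, LayerNormConfig::default(), vb)?;
/// let xs = Tensor::new(&[[[1f32, 2., 3.]]], &Device::Cpu)?;
/// let ys = ln.forward(&xs)?;
/// assert_eq!(ys.dims(), &[1, 1, 3]);
/// # Ok(()) }
/// ```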
pub fn layer_norm<C: Into<LayerNormConfig>>(
    size: usize,
    config: C,
    vb: crate::VarBuilder,
) -> Result<LayerNorm> {
    let config = config.into();
    let weight = vb.get_with_hints(size, "weight", crate::Init::Const(1.))?;
    let bias = if config.affine {
        Some(vb.get_with_hints(size, "bias", crate::Init::Const(0.))?)
    } else {
        None
    };
    Ok(LayerNorm {
        weight,
        bias,
        remove_mean: config.remove_mean,
        eps: config.eps,
    })
}

pub fn layer_norm_no_bias(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
    let config = LayerNormConfig {
        eps,
        remove_mean: true,
        affine: false,
    };
    layer_norm(size, config, vb)
}

/// RmsNorm is a specialized version of the LayerNorm module.
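///
/// It rescales the input by the root mean square over the last dimension, with no mean
/// subtraction and no bias. A minimal usage sketch:
///
/// ```rust
/// use candle::{Tensor, Device::Cpu};
/// use candle_nn::{Module, RmsNorm};
/// # fn main() -> candle::Result<()> {
/// let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
/// let rms = RmsNorm::new(w, 1e-5);
/// let ys = rms.forward(&Tensor::new(&[[[1f32, 2., 3.]]], &Cpu)?)?;
/// assert_eq!(ys.dims(), &[1, 1, 3]);
/// # Ok(()) }
/// ```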
#[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm);

impl RmsNorm {
    pub fn new(weight: Tensor, eps: f64) -> Self {
        Self(LayerNorm::rms_norm(weight, eps))
    }

    pub fn into_inner(self) -> LayerNorm {
        self.0
    }

    /// Forward pass that always uses the step-by-step tensor-op implementation rather than the
    /// fused kernel that `Module::forward` uses on contiguous tensors.
    pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
        self.0.forward(xs)
    }
}

impl Module for RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        if xs.is_contiguous() {
            crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
        } else {
            self.0.forward(xs)
        }
    }
}

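/// Creates an [`RmsNorm`] of the given hidden size, fetching its `weight` parameter through the
/// provided `VarBuilder`.
///
/// A minimal sketch under the same assumptions as the `layer_norm` example above:
///
/// ```rust
/// use candle::{DType, Device};
/// use candle_nn::{rms_norm, VarBuilder};
/// # fn main() -> candle::Result<()> {
/// let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
/// let rms = rms_norm(3, 1e-5, vb)?;
/// // The RMS variant has no bias parameter.
/// assert!(rms.into_inner().bias().is_none());
/// # Ok(()) }
/// ```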
pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<RmsNorm> {
    let config = LayerNormConfig {
        eps,
        remove_mean: false,
        affine: false,
    };
    Ok(RmsNorm(layer_norm(size, config, vb)?))
}