burn_core/nn/norm/
rms.rs

1use burn_tensor::DType;
2
3use crate as burn;
4
5use crate::config::Config;
6use crate::module::Module;
7use crate::module::Param;
8use crate::module::{Content, DisplaySettings, ModuleDisplay};
9use crate::nn::Initializer;
10use crate::tensor::Tensor;
11use crate::tensor::backend::Backend;
12
13/// Configuration to create a [RMS Norm](RmsNorm) layer using the [init function](RmsNormConfig::init).
14#[derive(Config)]
15pub struct RmsNormConfig {
16    /// The size of the input features.
17    pub d_model: usize,
18    /// A value required for numerical stability. Default: 1e-5
19    #[config(default = 1e-5)]
20    pub epsilon: f64,
21}
22
23impl RmsNormConfig {
24    /// Initialize a new [RMS Norm](RmsNorm) module.
25    ///
26    /// # Panics
27    ///
28    /// Panics if `epsilon` is not positive.
29    pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
30        assert!(self.epsilon > 0.0, "epsilon must be positive.");
31
32        let gamma = Initializer::Ones.init([self.d_model], device);
33
34        RmsNorm {
35            gamma,
36            epsilon: self.epsilon,
37        }
38    }
39}
40
41/// Applies RMS Normalization over an input tensor along the last dimension.
42///
43/// `Y = X / sqrt(mean(X^2) + eps) * gamma`
44///
45/// Where:
46/// - `X` is the input tensor
47/// - `Y` is the output tensor
48/// - `gamma` is the learnable weight
49/// - `mean` is the mean operation
50/// - `eps` is a small value to avoid division by zero.
51///
52/// Should be created using the [RmsNormConfig](RmsNormConfig) configuration.
53#[derive(Module, Debug)]
54#[module(custom_display)]
55pub struct RmsNorm<B: Backend> {
56    /// The learnable parameter to scale the normalized tensor
57    pub gamma: Param<Tensor<B, 1>>,
58    /// A value required for numerical stability
59    pub epsilon: f64,
60}
61
62impl<B: Backend> RmsNorm<B> {
63    /// Applies the forward pass on the input tensor.
64    ///
65    /// See the [RmsNorm](RmsNorm) documentation for more information.
66    ///
67    /// # Shapes
68    ///
69    /// - input: `[..., any, d_model]`
70    /// - output: `[..., any, d_model]`
71    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
72        // Calculate the root-mean-square norm of the input tensor along the last dimension
73        let dtype = x.dtype();
74        let rms =
75            (x.clone().cast(DType::F32).powf_scalar(2.0).mean_dim(D - 1) + self.epsilon).sqrt();
76        (x / rms.cast(dtype)) * self.gamma.val().unsqueeze()
77    }
78}
79
80impl<B: Backend> ModuleDisplay for RmsNorm<B> {
81    fn custom_settings(&self) -> Option<DisplaySettings> {
82        DisplaySettings::new()
83            .with_new_line_after_attribute(false)
84            .optional()
85    }
86
87    fn custom_content(&self, content: Content) -> Option<Content> {
88        let [d_model] = self.gamma.shape().dims();
89        content
90            .add("d_model", &d_model)
91            .add("epsilon", &self.epsilon)
92            .optional()
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use alloc::format;
    use burn_tensor::{Tolerance, ops::FloatElem};

    /// Float element type of the backend under test.
    type FT = FloatElem<TestBackend>;

    /// Compares the forward pass on a 3x3 ramp against precomputed values.
    #[test]
    fn rms_norm_forward() {
        let device = Default::default();
        let config = RmsNormConfig::new(3).with_epsilon(1e-5);
        let norm = config.init::<TestBackend>(&device);

        // Values 0 through 8 arranged as three rows of three features.
        let x = Tensor::arange(0..9, &device).float().reshape([3, 3]);
        let actual = norm.forward(x).to_data();

        let expected = TensorData::from([
            [0.0000, 0.7746, 1.5492],
            [0.7348, 0.9798, 1.2247],
            [0.8514, 0.9933, 1.1352],
        ]);
        let tolerance = Tolerance::<FT>::rel_abs(1e-4, 1e-4);
        actual.assert_approx_eq::<FT>(&expected, tolerance);
    }

    /// Pins the single-line text produced by the custom display.
    #[test]
    fn display() {
        let rms_norm = RmsNormConfig::new(6).init::<TestBackend>(&Default::default());
        let rendered = format!("{rms_norm}");

        assert_eq!(
            rendered,
            "RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}"
        );
    }
}