1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::module::Param;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
use crate::nn::Initializer;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Configuration to create a [RMS Norm](RmsNorm) layer using the [init function](RmsNormConfig::init).
#[derive(Config)]
pub struct RmsNormConfig {
    /// The size of the input features, i.e. the length of the last input dimension
    /// (`d_model` in the `[..., any, d_model]` input shape of [`RmsNorm::forward`]).
    pub d_model: usize,
    /// A value required for numerical stability. Default: 1e-5
    ///
    /// Must be strictly positive: [`RmsNormConfig::init`] panics otherwise.
    #[config(default = 1e-5)]
    pub epsilon: f64,
}

impl RmsNormConfig {
    /// Initialize a new [RMS Norm](RmsNorm) module on `device`.
    ///
    /// The scale parameter `gamma` is created with shape `[d_model]` and filled with ones.
    /// A freshly initialized layer therefore applies no extra scaling after normalization.
    ///
    /// # Panics
    ///
    /// Panics if `epsilon` is not positive (this includes NaN, since `NaN > 0.0` is false).
    pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
        let Self { d_model, epsilon } = self;

        assert!(*epsilon > 0.0, "epsilon must be positive.");

        RmsNorm {
            gamma: Initializer::Ones.init([*d_model], device),
            epsilon: *epsilon,
        }
    }
}

/// Applies RMS Normalization over an input tensor along the last dimension.
///
/// `Y = X / sqrt(mean(X^2) + eps) * gamma`
///
/// Where:
/// - `X` is the input tensor
/// - `Y` is the output tensor
/// - `gamma` is the learnable weight
/// - `mean` is the mean operation, taken over the last dimension
/// - `eps` is a small value to avoid division by zero.
///
/// The input is not mean-centered and no bias is added; `gamma` is the only learned parameter.
///
/// Should be created using the [RmsNormConfig](RmsNormConfig) configuration.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RmsNorm<B: Backend> {
    /// The learnable parameter to scale the normalized tensor, with shape `[d_model]`.
    ///
    /// Initialized to ones by [`RmsNormConfig::init`].
    pub gamma: Param<Tensor<B, 1>>,
    /// A value required for numerical stability.
    ///
    /// It is added to the mean of the squared input before the square root.
    pub epsilon: f64,
}

impl<B: Backend> RmsNorm<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// Each slice along the last dimension is divided by its root mean square,
    /// `sqrt(mean(x^2) + epsilon)`. The result is then scaled element-wise by `gamma`.
    /// See the [RmsNorm](RmsNorm) documentation for more information.
    ///
    /// # Shapes
    ///
    /// - input: `[..., any, d_model]`
    /// - output: `[..., any, d_model]`
    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        let feature_dim = D - 1;

        // The statistic is computed on a full-precision copy of the input.
        // The original `x` is kept for the final division.
        let mean_square = x
            .clone()
            .into_full_precision()
            .powf_scalar(2.0)
            .mean_dim(feature_dim);
        let rms = (mean_square + self.epsilon).sqrt();

        // `mean_dim` keeps the reduced axis, so `rms` has the same rank as `x`.
        // It broadcasts across the features.
        let normalized = x / Tensor::from_full_precision(rms);

        // Lift `gamma` from rank 1 to rank D so it broadcasts along the last dimension.
        normalized * self.gamma.val().unsqueeze()
    }
}

impl<B: Backend> ModuleDisplay for RmsNorm<B> {
    /// Prints all attributes on a single line rather than one attribute per line.
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    /// Shows the feature size (read back from the shape of `gamma`) and `epsilon`.
    fn custom_content(&self, content: Content) -> Option<Content> {
        let shape = self.gamma.shape();
        let [d_model] = shape.dims;

        let content = content.add("d_model", &d_model);
        content.add("epsilon", &self.epsilon).optional()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;
    use alloc::format;

    #[test]
    fn rms_norm_forward() {
        let device = Default::default();
        let module = RmsNormConfig::new(3)
            .with_epsilon(1e-5)
            .init::<TestBackend>(&device);

        let input = Tensor::arange(0..9, &device).float().reshape([3, 3]);

        let output = module.forward(input);

        let expected = TensorData::from([
            [0.0000, 0.7746, 1.5492],
            [0.7348, 0.9798, 1.2247],
            [0.8514, 0.9933, 1.1352],
        ]);
        output.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    fn rms_norm_forward_zero_input_stays_finite() {
        // An all-zero input has a mean square of 0. Only `epsilon` keeps the denominator
        // non-zero, so the output must be exactly 0 rather than NaN.
        let device = Default::default();
        let module = RmsNormConfig::new(3).init::<TestBackend>(&device);

        let input = Tensor::<TestBackend, 2>::zeros([2, 3], &device);

        let output = module.forward(input);

        let expected = TensorData::from([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
        output.to_data().assert_approx_eq(&expected, 4);
    }

    #[test]
    #[should_panic(expected = "epsilon must be positive.")]
    fn init_rejects_non_positive_epsilon() {
        // `RmsNormConfig::init` documents a panic for non-positive epsilon.
        let _ = RmsNormConfig::new(3)
            .with_epsilon(0.0)
            .init::<TestBackend>(&Default::default());
    }

    #[test]
    fn display() {
        let config = RmsNormConfig::new(6);
        let rms_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{}", rms_norm),
            "RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}"
        );
    }
}