// burn_nn/modules/norm/rms.rs
use burn_core as burn;

use burn::config::Config;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::Param;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::DType;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

13#[derive(Config, Debug)]
15pub struct RmsNormConfig {
16 pub d_model: usize,
18 #[config(default = 1e-5)]
20 pub epsilon: f64,
21}
22
23impl RmsNormConfig {
24 pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
30 assert!(self.epsilon > 0.0, "epsilon must be positive.");
31
32 let gamma = Initializer::Ones.init([self.d_model], device);
33
34 RmsNorm {
35 gamma,
36 epsilon: self.epsilon,
37 }
38 }
39}
40
41#[derive(Module, Debug)]
54#[module(custom_display)]
55pub struct RmsNorm<B: Backend> {
56 pub gamma: Param<Tensor<B, 1>>,
58 pub epsilon: f64,
60}
61
62impl<B: Backend> RmsNorm<B> {
63 pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
72 let dtype = x.dtype();
74 let rms = (x.clone().cast(DType::F32).square().mean_dim(D - 1) + self.epsilon).sqrt();
75 (x / rms.cast(dtype)) * self.gamma.val().unsqueeze()
76 }
77}
78
79impl<B: Backend> ModuleDisplay for RmsNorm<B> {
80 fn custom_settings(&self) -> Option<DisplaySettings> {
81 DisplaySettings::new()
82 .with_new_line_after_attribute(false)
83 .optional()
84 }
85
86 fn custom_content(&self, content: Content) -> Option<Content> {
87 let [d_model] = self.gamma.shape().dims();
88 content
89 .add("d_model", &d_model)
90 .add("epsilon", &self.epsilon)
91 .optional()
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use alloc::format;
    use burn::tensor::TensorData;
    use burn::tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Each row of a 2-D input is divided by its own root mean square (`gamma` is all ones).
    #[test]
    fn rms_norm_forward() {
        let device = Default::default();
        let module = RmsNormConfig::new(3)
            .with_epsilon(1e-5)
            .init::<TestBackend>(&device);

        let input = Tensor::arange(0..9, &device).float().reshape([3, 3]);

        let output = module.forward(input);

        let expected = TensorData::from([
            [0.0000, 0.7746, 1.5492],
            [0.7348, 0.9798, 1.2247],
            [0.8514, 0.9933, 1.1352],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// Only the last dimension is reduced; leading dimensions behave like a batch.
    #[test]
    fn rms_norm_forward_3d_reduces_last_dim_only() {
        let device = Default::default();
        let module = RmsNormConfig::new(3).init::<TestBackend>(&device);

        let input = Tensor::arange(0..12, &device).float().reshape([2, 2, 3]);

        let output = module.forward(input);

        // Last row: [9, 10, 11] / sqrt(302 / 3) = [9, 10, 11] / 10.033278.
        let expected = TensorData::from([
            [[0.0000, 0.7746, 1.5492], [0.7348, 0.9798, 1.2247]],
            [[0.8514, 0.9933, 1.1352], [0.89702, 0.99668, 1.09635]],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::default());
    }

    /// `init` enforces a strictly positive epsilon.
    #[test]
    #[should_panic(expected = "epsilon must be positive")]
    fn init_panics_on_non_positive_epsilon() {
        let device = Default::default();
        let _module = RmsNormConfig::new(3)
            .with_epsilon(0.0)
            .init::<TestBackend>(&device);
    }

    #[test]
    fn display() {
        let config = RmsNormConfig::new(6);
        let rms_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{rms_norm}"),
            "RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}"
        );
    }
}