1use burn_tensor::DType;
2
3use crate as burn;
4
5use crate::config::Config;
6use crate::module::Module;
7use crate::module::Param;
8use crate::module::{Content, DisplaySettings, ModuleDisplay};
9use crate::nn::Initializer;
10use crate::tensor::Tensor;
11use crate::tensor::backend::Backend;
12
/// Configuration to create an [`RmsNorm`] layer using [`RmsNormConfig::init`].
#[derive(Config)]
pub struct RmsNormConfig {
    /// Number of features in the last dimension of the input; also the length of the
    /// learnable `gamma` scale created by [`RmsNormConfig::init`].
    pub d_model: usize,
    /// Small value added to the mean of squares before the square root, for numerical
    /// stability. Must be strictly positive (checked in [`RmsNormConfig::init`]).
    /// Default: `1e-5`.
    #[config(default = 1e-5)]
    pub epsilon: f64,
}
22
23impl RmsNormConfig {
24 pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
30 assert!(self.epsilon > 0.0, "epsilon must be positive.");
31
32 let gamma = Initializer::Ones.init([self.d_model], device);
33
34 RmsNorm {
35 gamma,
36 epsilon: self.epsilon,
37 }
38 }
39}
40
/// Applies Root Mean Square normalization over the last dimension of the input tensor.
///
/// `Y = X / sqrt(mean(X^2) + epsilon) * gamma`
///
/// where the mean is taken over the last dimension and `gamma` is a learnable scale.
///
/// Should be created using [`RmsNormConfig`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct RmsNorm<B: Backend> {
    /// Learnable scale multiplied into the normalized tensor; one value per feature of the
    /// last dimension (initialized to ones by [`RmsNormConfig::init`]).
    pub gamma: Param<Tensor<B, 1>>,
    /// Value added to the mean of squares before the square root, for numerical stability.
    pub epsilon: f64,
}
61
62impl<B: Backend> RmsNorm<B> {
63 pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
72 let dtype = x.dtype();
74 let rms =
75 (x.clone().cast(DType::F32).powf_scalar(2.0).mean_dim(D - 1) + self.epsilon).sqrt();
76 (x / rms.cast(dtype)) * self.gamma.val().unsqueeze()
77 }
78}
79
80impl<B: Backend> ModuleDisplay for RmsNorm<B> {
81 fn custom_settings(&self) -> Option<DisplaySettings> {
82 DisplaySettings::new()
83 .with_new_line_after_attribute(false)
84 .optional()
85 }
86
87 fn custom_content(&self, content: Content) -> Option<Content> {
88 let [d_model] = self.gamma.shape().dims();
89 content
90 .add("d_model", &d_model)
91 .add("epsilon", &self.epsilon)
92 .optional()
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use crate::tensor::TensorData;
    use alloc::format;
    use burn_tensor::{Tolerance, ops::FloatElem};
    type FT = FloatElem<TestBackend>;

    /// Normalizes each row of a 3x3 matrix; `gamma` is all ones, so the output is the
    /// row divided by its root mean square.
    #[test]
    fn rms_norm_forward() {
        let device = Default::default();
        let module = RmsNormConfig::new(3)
            .with_epsilon(1e-5)
            .init::<TestBackend>(&device);

        let input = Tensor::arange(0..9, &device).float().reshape([3, 3]);

        let output = module.forward(input);

        let expected = TensorData::from([
            [0.0000, 0.7746, 1.5492],
            [0.7348, 0.9798, 1.2247],
            [0.8514, 0.9933, 1.1352],
        ]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-4, 1e-4));
    }

    /// With a leading batch dimension, normalization must still apply only over the last
    /// dimension, giving the same per-row values as the rank-2 case.
    #[test]
    fn rms_norm_forward_rank3() {
        let device = Default::default();
        let module = RmsNormConfig::new(3)
            .with_epsilon(1e-5)
            .init::<TestBackend>(&device);

        let input = Tensor::arange(0..9, &device).float().reshape([1, 3, 3]);

        let output = module.forward(input);

        let expected = TensorData::from([[
            [0.0000, 0.7746, 1.5492],
            [0.7348, 0.9798, 1.2247],
            [0.8514, 0.9933, 1.1352],
        ]]);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, Tolerance::rel_abs(1e-4, 1e-4));
    }

    /// `init` must reject a non-positive epsilon.
    #[test]
    #[should_panic(expected = "epsilon must be positive")]
    fn rms_norm_init_rejects_non_positive_epsilon() {
        let device = Default::default();
        let _ = RmsNormConfig::new(3)
            .with_epsilon(0.0)
            .init::<TestBackend>(&device);
    }

    #[test]
    fn display() {
        let config = RmsNormConfig::new(6);
        let rms_norm = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            format!("{}", rms_norm),
            "RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}"
        );
    }
}