1use burn_tensor::DType;
2
3use crate as burn;
4
5use crate::config::Config;
6use crate::module::Module;
7use crate::module::Param;
8use crate::module::{Content, DisplaySettings, ModuleDisplay};
9use crate::nn::Initializer;
10use crate::tensor::backend::Backend;
11use crate::tensor::Tensor;
12
13#[derive(Config)]
15pub struct RmsNormConfig {
16 pub d_model: usize,
18 #[config(default = 1e-5)]
20 pub epsilon: f64,
21}
22
23impl RmsNormConfig {
24 pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
30 assert!(self.epsilon > 0.0, "epsilon must be positive.");
31
32 let gamma = Initializer::Ones.init([self.d_model], device);
33
34 RmsNorm {
35 gamma,
36 epsilon: self.epsilon,
37 }
38 }
39}
40
41#[derive(Module, Debug)]
54#[module(custom_display)]
55pub struct RmsNorm<B: Backend> {
56 pub gamma: Param<Tensor<B, 1>>,
58 pub epsilon: f64,
60}
61
62impl<B: Backend> RmsNorm<B> {
63 pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
72 let dtype = x.dtype();
74 let rms =
75 (x.clone().cast(DType::F32).powf_scalar(2.0).mean_dim(D - 1) + self.epsilon).sqrt();
76 (x / rms.cast(dtype)) * self.gamma.val().unsqueeze()
77 }
78}
79
80impl<B: Backend> ModuleDisplay for RmsNorm<B> {
81 fn custom_settings(&self) -> Option<DisplaySettings> {
82 DisplaySettings::new()
83 .with_new_line_after_attribute(false)
84 .optional()
85 }
86
87 fn custom_content(&self, content: Content) -> Option<Content> {
88 let [d_model] = self.gamma.shape().dims();
89 content
90 .add("d_model", &d_model)
91 .add("epsilon", &self.epsilon)
92 .optional()
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorData;
    use crate::TestBackend;
    use alloc::format;

    /// Checks the forward pass on a 3x3 input holding the values 0..9.
    #[test]
    fn rms_norm_forward() {
        let device = Default::default();
        let norm = RmsNormConfig::new(3)
            .with_epsilon(1e-5)
            .init::<TestBackend>(&device);

        let x = Tensor::arange(0..9, &device).float().reshape([3, 3]);
        let y = norm.forward(x);

        // Each row is the input row divided by its root mean square.
        let expected = TensorData::from([
            [0.0000, 0.7746, 1.5492],
            [0.7348, 0.9798, 1.2247],
            [0.8514, 0.9933, 1.1352],
        ]);
        y.to_data().assert_approx_eq(&expected, 4);
    }

    /// Checks the single-line display output of the module.
    #[test]
    fn display() {
        let rms_norm = RmsNormConfig::new(6).init::<TestBackend>(&Default::default());
        let rendered = format!("{}", rms_norm);

        assert_eq!(
            rendered,
            "RmsNorm {d_model: 6, epsilon: 0.00001, params: 6}"
        );
    }
}