use candle::{DType, Module, Result, Tensor, D};

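/// Configuration for a [`LayerNorm`] layer.
///
/// `eps` is the small constant added to the variance (or mean square) before taking the square
/// root, `remove_mean` selects between classical layer norm (`true`) and RMS norm (`false`),
/// and `affine` controls whether a learnable `bias` is created alongside the `weight`.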
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
    pub eps: f64,
    pub remove_mean: bool,
    pub affine: bool,
}

impl Default for LayerNormConfig {
    fn default() -> Self {
        Self {
            eps: 1e-5,
            remove_mean: true,
            affine: true,
        }
    }
}

impl From<f64> for LayerNormConfig {
    fn from(eps: f64) -> Self {
        Self {
            eps,
            remove_mean: true,
            affine: true,
        }
    }
}

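/// Layer normalization as described in [Ba et al., 2016](https://arxiv.org/abs/1607.06450):
/// the input is normalized over its last dimension, then scaled by `weight` and, when a bias is
/// present, shifted by `bias`. When `remove_mean` is false the mean is not subtracted, which
/// turns this into RMS normalization.
///
/// A minimal usage sketch, assuming this module is exposed as part of the `candle_nn` crate and
/// a CPU device is available:
///
/// ```no_run
/// use candle::{Device, Module, Tensor};
/// use candle_nn::LayerNorm;
/// # fn main() -> candle::Result<()> {
/// let w = Tensor::new(&[1f32, 1., 1.], &Device::Cpu)?;
/// let b = Tensor::new(&[0f32, 0., 0.], &Device::Cpu)?;
/// let layer = LayerNorm::new(w, b, 1e-5);
/// // (batch, seq, hidden) input; normalization runs over the last (hidden) dimension.
/// let xs = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.]]], &Device::Cpu)?;
/// let ys = layer.forward(&xs)?;
/// assert_eq!(ys.dims(), &[1, 2, 3]);
/// # Ok(()) }
/// ```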
#[derive(Clone, Debug)]
pub struct LayerNorm {
    weight: Tensor,
    bias: Option<Tensor>,
    remove_mean: bool,
    eps: f64,
}

impl LayerNorm {
    /// Creates a layer norm with both a scale (`weight`) and a shift (`bias`).
    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
        Self {
            weight,
            bias: Some(bias),
            remove_mean: true,
            eps,
        }
    }

    /// Creates a layer norm without a bias term.
    pub fn new_no_bias(weight: Tensor, eps: f64) -> Self {
        Self {
            weight,
            bias: None,
            remove_mean: true,
            eps,
        }
    }

    /// Creates an RMS norm: no bias and no mean subtraction.
    pub fn rms_norm(weight: Tensor, eps: f64) -> Self {
        Self {
            weight,
            bias: None,
            remove_mean: false,
            eps,
        }
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    pub fn eps(&self) -> f64 {
        self.eps
    }

    pub fn remove_mean(&self) -> bool {
        self.remove_mean
    }
}

impl Module for LayerNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Fast path: the fused layer-norm kernel handles the standard case of a contiguous
        // input with mean removal and a bias.
        if x.is_contiguous() && self.remove_mean {
            if let Some(bias) = self.bias.as_ref() {
                return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
            }
        }
        // Fallback: compute the normalization with regular tensor ops, accumulating in f32
        // when the input is a half-precision dtype.
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        let hidden_size = x.dim(D::Minus1)?;
        let x = x.to_dtype(internal_dtype)?;
        let x = if self.remove_mean {
            // Subtract the mean over the last dimension (classical layer norm).
            let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
        // Mean of squares over the last dimension: the (biased) variance when the mean has
        // been removed, the plain mean square otherwise (RMS norm).
        let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
        let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}

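/// Creates a [`LayerNorm`] layer, retrieving its `weight` (and, when `affine` is set, its
/// `bias`) through the provided `VarBuilder`; freshly created variables are initialized to ones
/// and zeros respectively.
///
/// A minimal usage sketch, assuming the usual `candle_nn` re-exports (`layer_norm`,
/// `LayerNormConfig`, `VarBuilder`, `VarMap`):
///
/// ```no_run
/// use candle::{DType, Device};
/// use candle_nn::{layer_norm, LayerNormConfig, VarBuilder, VarMap};
/// # fn main() -> candle::Result<()> {
/// let varmap = VarMap::new();
/// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
/// let ln = layer_norm(768, LayerNormConfig::default(), vb)?;
/// # Ok(()) }
/// ```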
pub fn layer_norm<C: Into<LayerNormConfig>>(
    size: usize,
    config: C,
    vb: crate::VarBuilder,
) -> Result<LayerNorm> {
    let config = config.into();
    let weight = vb.get_with_hints(size, "weight", crate::Init::Const(1.))?;
    let bias = if config.affine {
        Some(vb.get_with_hints(size, "bias", crate::Init::Const(0.))?)
    } else {
        None
    };
    Ok(LayerNorm {
        weight,
        bias,
        remove_mean: config.remove_mean,
        eps: config.eps,
    })
}

/// Creates a [`LayerNorm`] without a bias term, i.e. with `affine` disabled.
pub fn layer_norm_no_bias(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
    let config = LayerNormConfig {
        eps,
        remove_mean: true,
        affine: false,
    };
    layer_norm(size, config, vb)
}

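/// RMS normalization: a thin wrapper around [`LayerNorm`] configured without mean subtraction
/// and without a bias term.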
#[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm);

impl RmsNorm {
    pub fn new(weight: Tensor, eps: f64) -> Self {
        Self(LayerNorm::rms_norm(weight, eps))
    }

    pub fn into_inner(self) -> LayerNorm {
        self.0
    }

    pub fn weight(&self) -> &Tensor {
        self.0.weight()
    }

    pub fn eps(&self) -> f64 {
        self.0.eps()
    }

    /// Applies RMS norm through the op-by-op [`LayerNorm`] implementation, bypassing the
    /// fused kernel used by [`Module::forward`].
    pub fn forward_diff(&self, xs: &Tensor) -> Result<Tensor> {
        self.0.forward(xs)
    }
}

impl Module for RmsNorm {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Use the fused RMS-norm kernel when the input is contiguous, otherwise fall back to
        // the op-by-op implementation.
        if xs.is_contiguous() {
            crate::ops::rms_norm(xs, &self.0.weight, self.0.eps as f32)
        } else {
            self.0.forward(xs)
        }
    }
}

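/// Creates an [`RmsNorm`] layer, retrieving a `weight` tensor of the given `size` through the
/// provided `VarBuilder` (initialized to ones when created fresh).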
pub fn rms_norm(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<RmsNorm> {
    let config = LayerNormConfig {
        eps,
        remove_mean: false,
        affine: false,
    };
    Ok(RmsNorm(layer_norm(size, config, vb)?))
}
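
#[cfg(test)]
mod tests {
    // A minimal sanity-check sketch, assuming a CPU device is available: with unit weight and
    // zero bias, layer norm should produce outputs with (approximately) zero mean and unit
    // mean square along the last dimension.
    use super::*;
    use candle::{Device, Module, Result, Tensor, D};

    #[test]
    fn layer_norm_zero_mean_unit_mean_square() -> Result<()> {
        let dev = Device::Cpu;
        let w = Tensor::new(&[1f32, 1., 1.], &dev)?;
        let b = Tensor::new(&[0f32, 0., 0.], &dev)?;
        let ln = LayerNorm::new(w, b, 1e-12);
        let xs = Tensor::new(&[[[1f32, 2., 3.], [4., 6., 8.]]], &dev)?;
        let ys = ln.forward(&xs)?;
        // Per-row mean and mean square of the normalized output.
        let mean: Vec<f32> = (ys.sum_keepdim(D::Minus1)? / 3.0)?.flatten_all()?.to_vec1()?;
        let mean_sq: Vec<f32> = (ys.sqr()?.sum_keepdim(D::Minus1)? / 3.0)?
            .flatten_all()?
            .to_vec1()?;
        for m in mean {
            assert!(m.abs() < 1e-5);
        }
        for v in mean_sq {
            assert!((v - 1.0).abs() < 1e-4);
        }
        Ok(())
    }
}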