//! Batch Normalization.
//!
//! This layer applies Batch Normalization over a mini-batch of inputs as described in
//! [`Batch Normalization`]. The input is expected to have at least two dimensions, with
//! the second dimension holding the features/channels.
//!
//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
use candle::{DType, Result, Tensor, Var};

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchNormConfig {
    pub eps: f64,
    pub remove_mean: bool,

    /// Note that `affine` has a different meaning here than in `LayerNorm`: when false, no
    /// learnable parameter is created at all, which amounts to a weight of 1 and a bias of 0.
    pub affine: bool,

    /// Controls the exponential moving average of the running stats:
    /// `running_stat * (1.0 - momentum) + batch_stat * momentum`. Defaults to 0.1.
    pub momentum: f64,
}

impl Default for BatchNormConfig {
    fn default() -> Self {
        Self {
            eps: 1e-5,
            remove_mean: true,
            affine: true,
            momentum: 0.1,
        }
    }
}

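/// Builds a config from a bare `eps` value, so e.g. `batch_norm(16, 1e-5, vb)` can be used
/// in place of a full [`BatchNormConfig`].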
impl From<f64> for BatchNormConfig {
    fn from(eps: f64) -> Self {
        Self {
            eps,
            ..Default::default()
        }
    }
}

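/// A batch-normalization layer, keeping its running mean and variance as [`Var`]s so that
/// they can be updated in place by [`BatchNorm::forward_train`] while being used as plain
/// (detached) tensors in inference mode.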
#[derive(Clone, Debug)]
pub struct BatchNorm {
    running_mean: Var,
    running_var: Var,
    weight_and_bias: Option<(Tensor, Tensor)>,
    remove_mean: bool,
    eps: f64,
    momentum: f64,
}

impl BatchNorm {
    fn check_validity(&self, num_features: usize) -> Result<()> {
        if self.eps < 0. {
            candle::bail!("batch-norm eps cannot be negative {}", self.eps)
        }
        if !(0.0..=1.0).contains(&self.momentum) {
            candle::bail!(
                "batch-norm momentum must be between 0 and 1, is {}",
                self.momentum
            )
        }
        if self.running_mean.dims() != [num_features] {
            candle::bail!(
                "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]",
                self.running_mean.shape(),
            )
        }
        if self.running_var.dims() != [num_features] {
            candle::bail!(
                "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]",
                self.running_var.shape(),
            )
        }
        if let Some((weight, bias)) = self.weight_and_bias.as_ref() {
            if weight.dims() != [num_features] {
                candle::bail!(
                    "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
                    weight.shape(),
                )
            }
            if bias.dims() != [num_features] {
                candle::bail!(
                    "batch-norm bias has unexpected shape {:?} should have shape [{num_features}]",
                    bias.shape(),
                )
            }
        }
        Ok(())
    }

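    /// Creates a batch-norm layer with learnable `weight` and `bias`, each of shape
    /// `[num_features]`, using the default momentum of 0.1.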
    pub fn new(
        num_features: usize,
        running_mean: Tensor,
        running_var: Tensor,
        weight: Tensor,
        bias: Tensor,
        eps: f64,
    ) -> Result<Self> {
        let out = Self {
            running_mean: Var::from_tensor(&running_mean)?,
            running_var: Var::from_tensor(&running_var)?,
            weight_and_bias: Some((weight, bias)),
            remove_mean: true,
            eps,
            momentum: 0.1,
        };
        out.check_validity(num_features)?;
        Ok(out)
    }

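    /// Creates a batch-norm layer without learnable parameters: the input is only normalized
    /// with the (running) mean and variance.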
    pub fn new_no_bias(
        num_features: usize,
        running_mean: Tensor,
        running_var: Tensor,
        eps: f64,
    ) -> Result<Self> {
        let out = Self {
            running_mean: Var::from_tensor(&running_mean)?,
            running_var: Var::from_tensor(&running_var)?,
            weight_and_bias: None,
            remove_mean: true,
            eps,
            momentum: 0.1,
        };
        out.check_validity(num_features)?;
        Ok(out)
    }

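    /// Same as [`BatchNorm::new`] but with a caller-provided `momentum` rather than the
    /// default of 0.1.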
    pub fn new_with_momentum(
        num_features: usize,
        running_mean: Tensor,
        running_var: Tensor,
        weight: Tensor,
        bias: Tensor,
        eps: f64,
        momentum: f64,
    ) -> Result<Self> {
        let out = Self {
            running_mean: Var::from_tensor(&running_mean)?,
            running_var: Var::from_tensor(&running_var)?,
            weight_and_bias: Some((weight, bias)),
            remove_mean: true,
            eps,
            momentum,
        };
        out.check_validity(num_features)?;
        Ok(out)
    }

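    /// Same as [`BatchNorm::new_no_bias`] but with a caller-provided `momentum`.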
    pub fn new_no_bias_with_momentum(
        num_features: usize,
        running_mean: Tensor,
        running_var: Tensor,
        eps: f64,
        momentum: f64,
    ) -> Result<Self> {
        let out = Self {
            running_mean: Var::from_tensor(&running_mean)?,
            running_var: Var::from_tensor(&running_var)?,
            weight_and_bias: None,
            remove_mean: true,
            eps,
            momentum,
        };
        out.check_validity(num_features)?;
        Ok(out)
    }

    pub fn running_mean(&self) -> &Tensor {
        self.running_mean.as_tensor()
    }

    pub fn running_var(&self) -> &Tensor {
        self.running_var.as_tensor()
    }

    pub fn eps(&self) -> f64 {
        self.eps
    }

    pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
        self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
    }

    pub fn momentum(&self) -> f64 {
        self.momentum
    }

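    /// Applies batch normalization in training mode: the statistics are computed from the
    /// current mini-batch and the running mean/variance are updated in place as
    /// `running_stat * (1.0 - momentum) + batch_stat * momentum` (the variance update uses the
    /// unbiased estimator, hence the `n / (n - 1)` correction below).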
    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
        let num_features = self.running_mean.as_tensor().dim(0)?;
        let x_dtype = x.dtype();
        let internal_dtype = match x_dtype {
            DType::F16 | DType::BF16 => DType::F32,
            d => d,
        };
        if x.rank() < 2 {
            candle::bail!(
                "batch-norm input tensor must have at least two dimensions ({:?})",
                x.shape()
            )
        }
        if x.dim(1)? != num_features {
            candle::bail!(
                "batch-norm input doesn't have the expected number of features ({:?} <> {})",
                x.shape(),
                num_features
            )
        }
        let x = x.to_dtype(internal_dtype)?;
        let x = x.transpose(0, 1)?;
        let x_dims_post_transpose = x.dims();
        // Flatten all the dimensions except the channel one.
        let x = x.flatten_from(1)?.contiguous()?;
        let x = if self.remove_mean {
            // The per-feature mean over all remaining dimensions.
            let mean_x = x.mean_keepdim(1)?;
            let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
                + (mean_x.flatten_all()? * self.momentum)?)?;
            self.running_mean.set(&updated_running_mean)?;
            x.broadcast_sub(&mean_x)?
        } else {
            x
        };
        // The per-feature (biased) variance: the mean has already been removed above.
        let norm_x = x.sqr()?.mean_keepdim(1)?;
        let updated_running_var = {
            let batch_size = x.dim(1)? as f64;
            let running_var_weight = 1.0 - self.momentum;
            let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
            ((self.running_var.as_tensor() * running_var_weight)?
                + (&norm_x.flatten_all()? * norm_x_weight)?)?
        };
        self.running_var.set(&updated_running_var)?;
        let x = x
            .broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
            .to_dtype(x_dtype)?;
        let x = match &self.weight_and_bias {
            None => x,
            Some((weight, bias)) => {
                let weight = weight.reshape(((), 1))?;
                let bias = bias.reshape(((), 1))?;
                x.broadcast_mul(&weight)?.broadcast_add(&bias)?
            }
        };
        x.reshape(x_dims_post_transpose)?.transpose(0, 1)
    }

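    /// Applies batch normalization in inference mode: the stored running mean and variance are
    /// used (detached, so no gradient flows into them) and nothing is updated.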
    fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
        // Reshape the running stats to [1, num_features, 1, ...] so that they broadcast over
        // the batch and spatial dimensions.
        let target_shape: Vec<usize> = x
            .dims()
            .iter()
            .enumerate()
            .map(|(idx, v)| if idx == 1 { *v } else { 1 })
            .collect();
        let target_shape = target_shape.as_slice();

        let x = x
            .broadcast_sub(
                &self
                    .running_mean
                    .as_detached_tensor()
                    .reshape(target_shape)?,
            )?
            .broadcast_div(
                &(self
                    .running_var
                    .as_detached_tensor()
                    .reshape(target_shape)?
                    + self.eps)?
                .sqrt()?,
            )?;

        match &self.weight_and_bias {
            None => Ok(x),
            Some((weight, bias)) => {
                let weight = weight.reshape(target_shape)?;
                let bias = bias.reshape(target_shape)?;
                x.broadcast_mul(&weight)?.broadcast_add(&bias)
            }
        }
    }
}

impl crate::ModuleT for BatchNorm {
    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            self.forward_train(x)
        } else {
            self.forward_eval(x)
        }
    }
}

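/// Creates a [`BatchNorm`] layer, fetching `running_mean`, `running_var` and (when
/// `config.affine` is set) `weight` and `bias` from the given `VarBuilder`.
///
/// A minimal usage sketch; the `VarMap`-backed builder and the input shape below are
/// illustrative assumptions, not part of this file:
///
/// ```no_run
/// use candle::{DType, Device, Tensor};
/// use candle_nn::{batch_norm, BatchNormConfig, ModuleT, VarBuilder, VarMap};
///
/// fn demo() -> candle::Result<()> {
///     let varmap = VarMap::new();
///     let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
///     let bn = batch_norm(5, BatchNormConfig::default(), vb)?;
///     // Shape is (batch, channels, length); dim 1 must match `num_features`.
///     let x = Tensor::randn(0f32, 1f32, (8, 5, 4), &Device::Cpu)?;
///     let y = bn.forward_t(&x, /* train = */ true)?;
///     assert_eq!(y.dims(), &[8, 5, 4]);
///     Ok(())
/// }
/// ```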
pub fn batch_norm<C: Into<BatchNormConfig>>(
    num_features: usize,
    config: C,
    vb: crate::VarBuilder,
) -> Result<BatchNorm> {
    use crate::Init;
    let config = config.into();
    if config.eps < 0. {
        candle::bail!("batch-norm eps cannot be negative {}", config.eps)
    }
    let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?;
    let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?;
    let weight_and_bias = if config.affine {
        let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?;
        let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?;
        Some((weight, bias))
    } else {
        None
    };
    Ok(BatchNorm {
        running_mean: Var::from_tensor(&running_mean)?,
        running_var: Var::from_tensor(&running_var)?,
        weight_and_bias,
        remove_mean: config.remove_mean,
        eps: config.eps,
        momentum: config.momentum,
    })
}
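
#[cfg(test)]
mod tests {
    // A minimal sketch of eval-mode behavior, not taken from the original file: with a zero
    // running mean, a unit running variance and no affine parameters, `forward_eval` should be
    // close to the identity (up to the `eps` term in the denominator).
    use super::*;
    use crate::ModuleT;
    use candle::{Device, Tensor};

    #[test]
    fn eval_with_identity_stats_is_nearly_identity() -> Result<()> {
        let dev = Device::Cpu;
        let running_mean = Tensor::zeros(3, DType::F32, &dev)?;
        let running_var = Tensor::ones(3, DType::F32, &dev)?;
        let bn = BatchNorm::new_no_bias(3, running_mean, running_var, 1e-5)?;
        let x = Tensor::randn(0f32, 1f32, (2, 3, 4), &dev)?;
        let y = bn.forward_t(&x, false)?;
        assert_eq!(y.dims(), x.dims());
        let mean_diff = (y - &x)?.abs()?.mean_all()?.to_scalar::<f32>()?;
        assert!(mean_diff < 1e-3, "eval batch-norm drifted: {mean_diff}");
        Ok(())
    }
}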