candle_nn/
batch_norm.rs

1//! Batch Normalization.
2//!
3//! This layer applies Batch Normalization over a mini-batch of inputs as described in [`Batch
//! Normalization`]. The input is expected to have at least two dimensions, with dimension one
//! being the feature/channel dimension.
5//!
//! Note that the running statistics are only updated when this layer is run in training mode
//! (`forward_t` with `train = true`); in inference mode the stored running stats are used as-is.
8//!
9//! [`Batch Normalization`]: https://arxiv.org/abs/1502.03167
10use candle::{DType, Result, Tensor, Var};
11
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BatchNormConfig {
    /// Small constant added to the variance before taking the square root, for numerical
    /// stability. Defaults to 1e-5.
    pub eps: f64,
    /// When true, the batch mean is subtracted before normalizing (and the running mean is
    /// updated); when false only the second-moment normalization is applied.
    pub remove_mean: bool,

    /// The meaning of affine here is different from LayerNorm: when false there is no learnable
    /// parameter at all, 1 used for gamma and 0 for beta.
    pub affine: bool,

    /// Controls exponential moving average of running stats. Defaults to 0.1
    ///
    /// `running_stat * (1.0 - momentum) + stat * momentum`.
    pub momentum: f64,
}
26
27impl Default for BatchNormConfig {
28    fn default() -> Self {
29        Self {
30            eps: 1e-5,
31            remove_mean: true,
32            affine: true,
33            momentum: 0.1,
34        }
35    }
36}
37
38impl From<f64> for BatchNormConfig {
39    fn from(eps: f64) -> Self {
40        Self {
41            eps,
42            ..Default::default()
43        }
44    }
45}
46
#[derive(Clone, Debug)]
pub struct BatchNorm {
    // Running statistics, shape [num_features]. Stored as `Var`s so `forward_train` can update
    // them in place via `set`.
    running_mean: Var,
    running_var: Var,
    // Optional (gamma, beta) learnable parameters, each of shape [num_features]; `None` when the
    // layer was built without affine parameters.
    weight_and_bias: Option<(Tensor, Tensor)>,
    // Whether the batch mean is subtracted during normalization.
    remove_mean: bool,
    // Added to the variance before the square root for numerical stability.
    eps: f64,
    // EMA factor for the running stats: `running * (1 - momentum) + batch_stat * momentum`.
    momentum: f64,
}
56
57impl BatchNorm {
58    fn check_validity(&self, num_features: usize) -> Result<()> {
59        if self.eps < 0. {
60            candle::bail!("batch-norm eps cannot be negative {}", self.eps)
61        }
62        if !(0.0..=1.0).contains(&self.momentum) {
63            candle::bail!(
64                "batch-norm momentum must be between 0 and 1, is {}",
65                self.momentum
66            )
67        }
68        if self.running_mean.dims() != [num_features] {
69            candle::bail!(
70                "batch-norm running mean has unexpected shape {:?} should have shape [{num_features}]",
71                self.running_mean.shape(),
72            )
73        }
74        if self.running_var.dims() != [num_features] {
75            candle::bail!(
76                "batch-norm running variance has unexpected shape {:?} should have shape [{num_features}]",
77                self.running_var.shape(),
78            )
79        }
80        if let Some((ref weight, ref bias)) = self.weight_and_bias.as_ref() {
81            if weight.dims() != [num_features] {
82                candle::bail!(
83                    "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
84                    weight.shape(),
85                )
86            }
87            if bias.dims() != [num_features] {
88                candle::bail!(
89                    "batch-norm weight has unexpected shape {:?} should have shape [{num_features}]",
90                    bias.shape(),
91                )
92            }
93        }
94        Ok(())
95    }
96
97    pub fn new(
98        num_features: usize,
99        running_mean: Tensor,
100        running_var: Tensor,
101        weight: Tensor,
102        bias: Tensor,
103        eps: f64,
104    ) -> Result<Self> {
105        let out = Self {
106            running_mean: Var::from_tensor(&running_mean)?,
107            running_var: Var::from_tensor(&running_var)?,
108            weight_and_bias: Some((weight, bias)),
109            remove_mean: true,
110            eps,
111            momentum: 0.1,
112        };
113        out.check_validity(num_features)?;
114        Ok(out)
115    }
116
117    pub fn new_no_bias(
118        num_features: usize,
119        running_mean: Tensor,
120        running_var: Tensor,
121        eps: f64,
122    ) -> Result<Self> {
123        let out = Self {
124            running_mean: Var::from_tensor(&running_mean)?,
125            running_var: Var::from_tensor(&running_var)?,
126            weight_and_bias: None,
127            remove_mean: true,
128            eps,
129            momentum: 0.1,
130        };
131        out.check_validity(num_features)?;
132        Ok(out)
133    }
134
135    pub fn new_with_momentum(
136        num_features: usize,
137        running_mean: Tensor,
138        running_var: Tensor,
139        weight: Tensor,
140        bias: Tensor,
141        eps: f64,
142        momentum: f64,
143    ) -> Result<Self> {
144        let out = Self {
145            running_mean: Var::from_tensor(&running_mean)?,
146            running_var: Var::from_tensor(&running_var)?,
147            weight_and_bias: Some((weight, bias)),
148            remove_mean: true,
149            eps,
150            momentum,
151        };
152        out.check_validity(num_features)?;
153        Ok(out)
154    }
155
156    pub fn new_no_bias_with_momentum(
157        num_features: usize,
158        running_mean: Tensor,
159        running_var: Tensor,
160        eps: f64,
161        momentum: f64,
162    ) -> Result<Self> {
163        let out = Self {
164            running_mean: Var::from_tensor(&running_mean)?,
165            running_var: Var::from_tensor(&running_var)?,
166            weight_and_bias: None,
167            remove_mean: true,
168            eps,
169            momentum,
170        };
171        out.check_validity(num_features)?;
172        Ok(out)
173    }
174
175    pub fn running_mean(&self) -> &Tensor {
176        self.running_mean.as_tensor()
177    }
178
179    pub fn running_var(&self) -> &Tensor {
180        self.running_var.as_tensor()
181    }
182
183    pub fn eps(&self) -> f64 {
184        self.eps
185    }
186
187    pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
188        self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
189    }
190
191    pub fn momentum(&self) -> f64 {
192        self.momentum
193    }
194
195    pub fn forward_train(&self, x: &Tensor) -> Result<Tensor> {
196        let num_features = self.running_mean.as_tensor().dim(0)?;
197        let x_dtype = x.dtype();
198        let internal_dtype = match x_dtype {
199            DType::F16 | DType::BF16 => DType::F32,
200            d => d,
201        };
202        if x.rank() < 2 {
203            candle::bail!(
204                "batch-norm input tensor must have at least two dimensions ({:?})",
205                x.shape()
206            )
207        }
208        if x.dim(1)? != num_features {
209            candle::bail!(
210                "batch-norm input doesn't have the expected number of features ({:?} <> {})",
211                x.shape(),
212                num_features
213            )
214        }
215        let x = x.to_dtype(internal_dtype)?;
216        let x = x.transpose(0, 1)?;
217        let x_dims_post_transpose = x.dims();
218        // Flatten all the dimensions exception the channel one as this performs a Spatial Batch
219        // Normalization.
220        let x = x.flatten_from(1)?.contiguous()?;
221        let x = if self.remove_mean {
222            // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
223            let mean_x = x.mean_keepdim(1)?;
224            let updated_running_mean = ((self.running_mean.as_tensor() * (1.0 - self.momentum))?
225                + (mean_x.flatten_all()? * self.momentum)?)?;
226            self.running_mean.set(&updated_running_mean)?;
227            x.broadcast_sub(&mean_x)?
228        } else {
229            x
230        };
231        // The mean is taken over dim 1 as this is the batch dim after the transpose(0, 1) above.
232        let norm_x = x.sqr()?.mean_keepdim(1)?;
233        let updated_running_var = {
234            let batch_size = x.dim(1)? as f64;
235            let running_var_weight = 1.0 - self.momentum;
236            let norm_x_weight = self.momentum * batch_size / (batch_size - 1.0);
237            ((self.running_var.as_tensor() * running_var_weight)?
238                + (&norm_x.flatten_all()? * norm_x_weight)?)?
239        };
240        self.running_var.set(&updated_running_var)?;
241        let x = x
242            .broadcast_div(&(norm_x + self.eps)?.sqrt()?)?
243            .to_dtype(x_dtype)?;
244        let x = match &self.weight_and_bias {
245            None => x,
246            Some((weight, bias)) => {
247                let weight = weight.reshape(((), 1))?;
248                let bias = bias.reshape(((), 1))?;
249                x.broadcast_mul(&weight)?.broadcast_add(&bias)?
250            }
251        };
252        x.reshape(x_dims_post_transpose)?.transpose(0, 1)
253    }
254
255    fn forward_eval(&self, x: &Tensor) -> Result<Tensor> {
256        let target_shape: Vec<usize> = x
257            .dims()
258            .iter()
259            .enumerate()
260            .map(|(idx, v)| if idx == 1 { *v } else { 1 })
261            .collect();
262        let target_shape = target_shape.as_slice();
263
264        let x = x
265            .broadcast_sub(
266                &self
267                    .running_mean
268                    .as_detached_tensor()
269                    .reshape(target_shape)?,
270            )?
271            .broadcast_div(
272                &(self
273                    .running_var
274                    .as_detached_tensor()
275                    .reshape(target_shape)?
276                    + self.eps)?
277                    .sqrt()?,
278            )?;
279
280        match &self.weight_and_bias {
281            None => Ok(x),
282            Some((weight, bias)) => {
283                let weight = weight.reshape(target_shape)?;
284                let bias = bias.reshape(target_shape)?;
285                x.broadcast_mul(&weight)?.broadcast_add(&bias)
286            }
287        }
288    }
289}
290
291impl crate::ModuleT for BatchNorm {
292    fn forward_t(&self, x: &Tensor, train: bool) -> Result<Tensor> {
293        if train {
294            self.forward_train(x)
295        } else {
296            self.forward_eval(x)
297        }
298    }
299}
300
301pub fn batch_norm<C: Into<BatchNormConfig>>(
302    num_features: usize,
303    config: C,
304    vb: crate::VarBuilder,
305) -> Result<BatchNorm> {
306    use crate::Init;
307    let config = config.into();
308    if config.eps < 0. {
309        candle::bail!("batch-norm eps cannot be negative {}", config.eps)
310    }
311    let running_mean = vb.get_with_hints(num_features, "running_mean", Init::Const(0.))?;
312    let running_var = vb.get_with_hints(num_features, "running_var", Init::Const(1.))?;
313    let weight_and_bias = if config.affine {
314        let weight = vb.get_with_hints(num_features, "weight", Init::Const(1.))?;
315        let bias = vb.get_with_hints(num_features, "bias", Init::Const(0.))?;
316        Some((weight, bias))
317    } else {
318        None
319    };
320    Ok(BatchNorm {
321        running_mean: Var::from_tensor(&running_mean)?,
322        running_var: Var::from_tensor(&running_var)?,
323        weight_and_bias,
324        remove_mean: config.remove_mean,
325        eps: config.eps,
326        momentum: config.momentum,
327    })
328}