// candle_nn/group_norm.rs

//! Group Normalization.
//!
//! This layer applies Group Normalization over a mini-batch of inputs.
use candle::{DType, Result, Tensor};

// This group norm version handles both weight and bias so removes the mean.
#[derive(Clone, Debug)]
pub struct GroupNorm {
    // Per-channel scale (gamma); reshaped to broadcast over dim 1 in `forward`,
    // so it must hold `num_channels` elements.
    weight: Tensor,
    // Per-channel shift (beta); same element count as `weight`.
    bias: Tensor,
    // Added to the variance before the square root for numerical stability.
    eps: f64,
    // Expected channel count C (dim 1 of the input); validated in `forward`.
    num_channels: usize,
    // Number of groups the channels are split into; must divide `num_channels`.
    num_groups: usize,
}
15
16impl GroupNorm {
17    pub fn new(
18        weight: Tensor,
19        bias: Tensor,
20        num_channels: usize,
21        num_groups: usize,
22        eps: f64,
23    ) -> Result<Self> {
24        if !num_channels.is_multiple_of(num_groups) {
25            candle::bail!(
26                "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})"
27            )
28        }
29        Ok(Self {
30            weight,
31            bias,
32            eps,
33            num_channels,
34            num_groups,
35        })
36    }
37}
38
39impl crate::Module for GroupNorm {
40    fn forward(&self, x: &Tensor) -> Result<Tensor> {
41        let x_shape = x.dims();
42        if x_shape.len() <= 2 {
43            candle::bail!("input rank for GroupNorm should be at least 3");
44        }
45        let (b_sz, n_channels) = (x_shape[0], x_shape[1]);
46        let hidden_size = x_shape[2..].iter().product::<usize>() * n_channels / self.num_groups;
47        if n_channels != self.num_channels {
48            candle::bail!(
49                "unexpected num-channels in GroupNorm ({n_channels} <> {}",
50                self.num_channels
51            )
52        }
53        let x_dtype = x.dtype();
54        let internal_dtype = match x_dtype {
55            DType::F16 | DType::BF16 => DType::F32,
56            d => d,
57        };
58        let x = x.reshape((b_sz, self.num_groups, hidden_size))?;
59        let x = x.to_dtype(internal_dtype)?;
60        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
61        let x = x.broadcast_sub(&mean_x)?;
62        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
63        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
64        let mut w_dims = vec![1; x_shape.len()];
65        w_dims[1] = n_channels;
66        let weight = self.weight.reshape(w_dims.clone())?;
67        let bias = self.bias.reshape(w_dims)?;
68        x_normed
69            .to_dtype(x_dtype)?
70            .reshape(x_shape)?
71            .broadcast_mul(&weight)?
72            .broadcast_add(&bias)
73    }
74}
75
76pub fn group_norm(
77    num_groups: usize,
78    num_channels: usize,
79    eps: f64,
80    vb: crate::VarBuilder,
81) -> Result<GroupNorm> {
82    let weight = vb.get_with_hints(num_channels, "weight", crate::Init::Const(1.))?;
83    let bias = vb.get_with_hints(num_channels, "bias", crate::Init::Const(0.))?;
84    GroupNorm::new(weight, bias, num_channels, num_groups, eps)
85}