1use candle::{DType, Result, Tensor};
5
6#[derive(Clone, Debug)]
8pub struct GroupNorm {
9 weight: Tensor,
10 bias: Tensor,
11 eps: f64,
12 num_channels: usize,
13 num_groups: usize,
14}
15
16impl GroupNorm {
17 pub fn new(
18 weight: Tensor,
19 bias: Tensor,
20 num_channels: usize,
21 num_groups: usize,
22 eps: f64,
23 ) -> Result<Self> {
24 if !num_channels.is_multiple_of(num_groups) {
25 candle::bail!(
26 "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})"
27 )
28 }
29 Ok(Self {
30 weight,
31 bias,
32 eps,
33 num_channels,
34 num_groups,
35 })
36 }
37}
38
39impl crate::Module for GroupNorm {
40 fn forward(&self, x: &Tensor) -> Result<Tensor> {
41 let x_shape = x.dims();
42 if x_shape.len() <= 2 {
43 candle::bail!("input rank for GroupNorm should be at least 3");
44 }
45 let (b_sz, n_channels) = (x_shape[0], x_shape[1]);
46 let hidden_size = x_shape[2..].iter().product::<usize>() * n_channels / self.num_groups;
47 if n_channels != self.num_channels {
48 candle::bail!(
49 "unexpected num-channels in GroupNorm ({n_channels} <> {}",
50 self.num_channels
51 )
52 }
53 let x_dtype = x.dtype();
54 let internal_dtype = match x_dtype {
55 DType::F16 | DType::BF16 => DType::F32,
56 d => d,
57 };
58 let x = x.reshape((b_sz, self.num_groups, hidden_size))?;
59 let x = x.to_dtype(internal_dtype)?;
60 let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
61 let x = x.broadcast_sub(&mean_x)?;
62 let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
63 let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
64 let mut w_dims = vec![1; x_shape.len()];
65 w_dims[1] = n_channels;
66 let weight = self.weight.reshape(w_dims.clone())?;
67 let bias = self.bias.reshape(w_dims)?;
68 x_normed
69 .to_dtype(x_dtype)?
70 .reshape(x_shape)?
71 .broadcast_mul(&weight)?
72 .broadcast_add(&bias)
73 }
74}
75
76pub fn group_norm(
77 num_groups: usize,
78 num_channels: usize,
79 eps: f64,
80 vb: crate::VarBuilder,
81) -> Result<GroupNorm> {
82 let weight = vb.get_with_hints(num_channels, "weight", crate::Init::Const(1.))?;
83 let bias = vb.get_with_hints(num_channels, "bias", crate::Init::Const(0.))?;
84 GroupNorm::new(weight, bias, num_channels, num_groups, eps)
85}