// aprender/nn/normalization/group_norm.rs
#[allow(clippy::wildcard_imports)]
2use super::*;
3
4impl GroupNorm {
5 #[must_use]
16 pub fn new(num_groups: usize, num_channels: usize) -> Self {
17 assert!(
18 num_channels % num_groups == 0,
19 "num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
20 );
21
22 Self {
23 num_groups,
24 num_channels,
25 eps: 1e-5,
26 weight: constant(&[num_channels], 1.0).requires_grad(),
27 bias: zeros(&[num_channels]).requires_grad(),
28 affine: true,
29 }
30 }
31
32 #[must_use]
34 pub fn with_eps(num_groups: usize, num_channels: usize, eps: f32) -> Self {
35 let mut layer = Self::new(num_groups, num_channels);
36 layer.eps = eps;
37 layer
38 }
39
40 #[must_use]
42 pub fn without_affine(num_groups: usize, num_channels: usize) -> Self {
43 assert!(
44 num_channels % num_groups == 0,
45 "num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
46 );
47
48 Self {
49 num_groups,
50 num_channels,
51 eps: 1e-5,
52 weight: constant(&[num_channels], 1.0),
53 bias: zeros(&[num_channels]),
54 affine: false,
55 }
56 }
57
58 #[must_use]
60 pub fn num_groups(&self) -> usize {
61 self.num_groups
62 }
63
64 #[must_use]
66 pub fn num_channels(&self) -> usize {
67 self.num_channels
68 }
69}
70
71impl Module for GroupNorm {
72 fn forward(&self, input: &Tensor) -> Tensor {
73 let shape = input.shape();
74 assert!(
75 shape.len() >= 2,
76 "GroupNorm expects at least 2D input, got {}D",
77 shape.len()
78 );
79
80 let (batch_size, channels) = (shape[0], shape[1]);
81 assert_eq!(
82 channels, self.num_channels,
83 "Expected {} channels, got {}",
84 self.num_channels, channels
85 );
86
87 let channels_per_group = channels / self.num_groups;
88 let spatial_size: usize = shape[2..].iter().product();
89 let group_size = channels_per_group * spatial_size;
90
91 let input_data = input.data();
92 let mut output_data = vec![0.0; input_data.len()];
93
94 for n in 0..batch_size {
95 for g in 0..self.num_groups {
96 let mut sum = 0.0;
98
99 for c in 0..channels_per_group {
100 let channel_idx = g * channels_per_group + c;
101 for s in 0..spatial_size {
102 let idx = n * channels * spatial_size + channel_idx * spatial_size + s;
103 sum += input_data[idx];
104 }
105 }
106
107 let mean = sum / group_size as f32;
108
109 let mut var_sum = 0.0;
110 for c in 0..channels_per_group {
111 let channel_idx = g * channels_per_group + c;
112 for s in 0..spatial_size {
113 let idx = n * channels * spatial_size + channel_idx * spatial_size + s;
114 var_sum += (input_data[idx] - mean).powi(2);
115 }
116 }
117
118 let var = var_sum / group_size as f32;
119 let std_inv = 1.0 / (var + self.eps).sqrt();
120
121 for c in 0..channels_per_group {
123 let channel_idx = g * channels_per_group + c;
124 for s in 0..spatial_size {
125 let idx = n * channels * spatial_size + channel_idx * spatial_size + s;
126 let normalized = (input_data[idx] - mean) * std_inv;
127
128 output_data[idx] = if self.affine {
129 normalized * self.weight.data()[channel_idx]
130 + self.bias.data()[channel_idx]
131 } else {
132 normalized
133 };
134 }
135 }
136 }
137 }
138
139 Tensor::new(&output_data, shape)
140 }
141
142 fn parameters(&self) -> Vec<&Tensor> {
143 if self.affine {
144 vec![&self.weight, &self.bias]
145 } else {
146 vec![]
147 }
148 }
149
150 fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
151 if self.affine {
152 vec![&mut self.weight, &mut self.bias]
153 } else {
154 vec![]
155 }
156 }
157}
158
/// Root Mean Square Layer Normalization.
///
/// Normalizes each group of trailing dimensions by its root-mean-square
/// (no mean subtraction), optionally followed by a learnable
/// per-element scale.
#[derive(Debug)]
pub struct RMSNorm {
    /// Trailing input dimensions normalized together as one group.
    normalized_shape: Vec<usize>,
    /// Small constant added to the mean square for numerical stability.
    eps: f32,
    /// Per-element scale; only applied when `elementwise_affine` is true.
    weight: Tensor,
    /// Whether `forward` applies the learnable scale and `parameters`
    /// exposes it.
    elementwise_affine: bool,
}
198
199impl RMSNorm {
200 #[must_use]
206 pub fn new(normalized_shape: &[usize]) -> Self {
207 let numel: usize = normalized_shape.iter().product();
208 Self {
209 normalized_shape: normalized_shape.to_vec(),
210 eps: 1e-6, weight: constant(&[numel], 1.0).requires_grad(),
212 elementwise_affine: true,
213 }
214 }
215
216 #[must_use]
218 pub fn with_eps(normalized_shape: &[usize], eps: f32) -> Self {
219 let mut layer = Self::new(normalized_shape);
220 layer.eps = eps;
221 layer
222 }
223
224 #[must_use]
226 pub fn without_affine(normalized_shape: &[usize]) -> Self {
227 let numel: usize = normalized_shape.iter().product();
228 Self {
229 normalized_shape: normalized_shape.to_vec(),
230 eps: 1e-6,
231 weight: constant(&[numel], 1.0),
232 elementwise_affine: false,
233 }
234 }
235
236 #[must_use]
238 pub fn normalized_shape(&self) -> &[usize] {
239 &self.normalized_shape
240 }
241
242 #[must_use]
244 pub fn eps(&self) -> f32 {
245 self.eps
246 }
247
248 pub fn set_weight(&mut self, weight: Tensor) {
252 self.weight = weight;
253 }
254
255 #[must_use]
257 pub fn weight(&self) -> &Tensor {
258 &self.weight
259 }
260
261 #[must_use]
270 pub fn placeholder(normalized_shape: &[usize]) -> Self {
271 Self {
272 normalized_shape: normalized_shape.to_vec(),
273 eps: 1e-6,
274 weight: Tensor::new(&[1.0], &[1]),
275 elementwise_affine: true,
276 }
277 }
278}
279
280impl Module for RMSNorm {
281 #[provable_contracts_macros::contract("rmsnorm-kernel-v1", equation = "rmsnorm")]
284 fn forward(&self, input: &Tensor) -> Tensor {
285 let shape = input.shape();
286 let norm_size: usize = self.normalized_shape.iter().product();
287
288 assert!(
290 shape.len() >= self.normalized_shape.len(),
291 "Input must have at least as many dimensions as normalized_shape"
292 );
293
294 let start_dim = shape.len() - self.normalized_shape.len();
296 for (i, &ns) in self.normalized_shape.iter().enumerate() {
297 assert_eq!(
298 shape[start_dim + i],
299 ns,
300 "Input shape doesn't match normalized_shape at dim {i}"
301 );
302 }
303
304 if self.elementwise_affine {
305 crate::nn::functional::rms_norm(input, &self.weight, self.eps)
307 } else {
308 let batch_dims: usize = shape[..start_dim].iter().product();
310 let input_data = input.data();
311 let mut output_data = vec![0.0; input_data.len()];
312
313 for b in 0..batch_dims {
314 let offset = b * norm_size;
315 let slice = &input_data[offset..offset + norm_size];
316
317 let mean_sq: f32 = slice.iter().map(|&x| x * x).sum::<f32>() / norm_size as f32;
318 let rms_inv = 1.0 / (mean_sq + self.eps).sqrt();
319
320 for i in 0..norm_size {
321 output_data[offset + i] = slice[i] * rms_inv;
322 }
323 }
324
325 Tensor::new(&output_data, shape)
326 }
327 }
328
329 fn parameters(&self) -> Vec<&Tensor> {
330 if self.elementwise_affine {
331 vec![&self.weight]
332 } else {
333 vec![]
334 }
335 }
336
337 fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
338 if self.elementwise_affine {
339 vec![&mut self.weight]
340 } else {
341 vec![]
342 }
343 }
344}
345
/// Instance Normalization, implemented as `GroupNorm` with one group per
/// channel, so statistics are computed per sample and per channel.
#[derive(Debug)]
pub struct InstanceNorm {
    /// Delegate that performs the actual normalization.
    inner: GroupNorm,
}
356
357impl InstanceNorm {
358 #[must_use]
360 pub fn new(num_channels: usize) -> Self {
361 Self {
362 inner: GroupNorm::new(num_channels, num_channels),
363 }
364 }
365
366 #[must_use]
368 pub fn without_affine(num_channels: usize) -> Self {
369 Self {
370 inner: GroupNorm::without_affine(num_channels, num_channels),
371 }
372 }
373}
374
375impl Module for InstanceNorm {
376 fn forward(&self, input: &Tensor) -> Tensor {
377 self.inner.forward(input)
378 }
379
380 fn parameters(&self) -> Vec<&Tensor> {
381 self.inner.parameters()
382 }
383
384 fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
385 self.inner.parameters_mut()
386 }
387}
388
// Unit tests live in the sibling `tests.rs` file, compiled only for test
// builds.
#[cfg(test)]
#[path = "tests.rs"]
mod tests;