// aprender/nn/normalization/group_norm.rs
1#[allow(clippy::wildcard_imports)]
2use super::*;
3
4impl GroupNorm {
5    /// Create a new `GroupNorm` layer.
6    ///
7    /// # Arguments
8    ///
9    /// * `num_groups` - Number of groups to divide channels into
10    /// * `num_channels` - Number of channels (must be divisible by `num_groups`)
11    ///
12    /// # Panics
13    ///
14    /// Panics if `num_channels` is not divisible by `num_groups`.
15    #[must_use]
16    pub fn new(num_groups: usize, num_channels: usize) -> Self {
17        assert!(
18            num_channels % num_groups == 0,
19            "num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
20        );
21
22        Self {
23            num_groups,
24            num_channels,
25            eps: 1e-5,
26            weight: constant(&[num_channels], 1.0).requires_grad(),
27            bias: zeros(&[num_channels]).requires_grad(),
28            affine: true,
29        }
30    }
31
32    /// Create `GroupNorm` with custom epsilon.
33    #[must_use]
34    pub fn with_eps(num_groups: usize, num_channels: usize, eps: f32) -> Self {
35        let mut layer = Self::new(num_groups, num_channels);
36        layer.eps = eps;
37        layer
38    }
39
40    /// Create `GroupNorm` without learnable parameters.
41    #[must_use]
42    pub fn without_affine(num_groups: usize, num_channels: usize) -> Self {
43        assert!(
44            num_channels % num_groups == 0,
45            "num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
46        );
47
48        Self {
49            num_groups,
50            num_channels,
51            eps: 1e-5,
52            weight: constant(&[num_channels], 1.0),
53            bias: zeros(&[num_channels]),
54            affine: false,
55        }
56    }
57
58    /// Get number of groups.
59    #[must_use]
60    pub fn num_groups(&self) -> usize {
61        self.num_groups
62    }
63
64    /// Get number of channels.
65    #[must_use]
66    pub fn num_channels(&self) -> usize {
67        self.num_channels
68    }
69}
70
71impl Module for GroupNorm {
72    fn forward(&self, input: &Tensor) -> Tensor {
73        let shape = input.shape();
74        assert!(
75            shape.len() >= 2,
76            "GroupNorm expects at least 2D input, got {}D",
77            shape.len()
78        );
79
80        let (batch_size, channels) = (shape[0], shape[1]);
81        assert_eq!(
82            channels, self.num_channels,
83            "Expected {} channels, got {}",
84            self.num_channels, channels
85        );
86
87        let channels_per_group = channels / self.num_groups;
88        let spatial_size: usize = shape[2..].iter().product();
89        let group_size = channels_per_group * spatial_size;
90
91        let input_data = input.data();
92        let mut output_data = vec![0.0; input_data.len()];
93
94        for n in 0..batch_size {
95            for g in 0..self.num_groups {
96                // Compute mean and variance for this group
97                let mut sum = 0.0;
98
99                for c in 0..channels_per_group {
100                    let channel_idx = g * channels_per_group + c;
101                    for s in 0..spatial_size {
102                        let idx = n * channels * spatial_size + channel_idx * spatial_size + s;
103                        sum += input_data[idx];
104                    }
105                }
106
107                let mean = sum / group_size as f32;
108
109                let mut var_sum = 0.0;
110                for c in 0..channels_per_group {
111                    let channel_idx = g * channels_per_group + c;
112                    for s in 0..spatial_size {
113                        let idx = n * channels * spatial_size + channel_idx * spatial_size + s;
114                        var_sum += (input_data[idx] - mean).powi(2);
115                    }
116                }
117
118                let var = var_sum / group_size as f32;
119                let std_inv = 1.0 / (var + self.eps).sqrt();
120
121                // Normalize and apply affine transformation
122                for c in 0..channels_per_group {
123                    let channel_idx = g * channels_per_group + c;
124                    for s in 0..spatial_size {
125                        let idx = n * channels * spatial_size + channel_idx * spatial_size + s;
126                        let normalized = (input_data[idx] - mean) * std_inv;
127
128                        output_data[idx] = if self.affine {
129                            normalized * self.weight.data()[channel_idx]
130                                + self.bias.data()[channel_idx]
131                        } else {
132                            normalized
133                        };
134                    }
135                }
136            }
137        }
138
139        Tensor::new(&output_data, shape)
140    }
141
142    fn parameters(&self) -> Vec<&Tensor> {
143        if self.affine {
144            vec![&self.weight, &self.bias]
145        } else {
146            vec![]
147        }
148    }
149
150    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
151        if self.affine {
152            vec![&mut self.weight, &mut self.bias]
153        } else {
154            vec![]
155        }
156    }
157}
158
/// Root Mean Square Layer Normalization (Zhang & Sennrich, 2019).
///
/// A simplified version of `LayerNorm` that only uses the root mean square
/// for normalization, without centering (no mean subtraction).
/// This is faster than `LayerNorm` while achieving similar results.
///
/// ```text
/// y = x / RMS(x) * gamma
/// RMS(x) = sqrt(mean(x^2) + eps)
/// ```
///
/// Used in `LLaMA`, Gemma, and other modern transformers.
///
/// # Example
///
/// ```ignore
/// use aprender::nn::{RMSNorm, Module};
/// use aprender::autograd::Tensor;
///
/// let norm = RMSNorm::new(&[256]);  // Normalize over 256 features
/// let x = Tensor::randn(&[32, 10, 256]);  // [batch, seq, features]
/// let y = norm.forward(&x);  // Normalized
/// ```
///
/// # References
///
/// - Zhang, B., & Sennrich, R. (2019). Root Mean Square Layer Normalization.
///   `NeurIPS`.
#[derive(Debug)]
pub struct RMSNorm {
    /// Shape of the trailing dimensions normalized over (e.g. `[hidden_dim]`)
    normalized_shape: Vec<usize>,
    /// Small constant for numerical stability (default 1e-6; smaller than
    /// LayerNorm's usual 1e-5, matching common LLM configurations)
    eps: f32,
    /// Learnable scale parameter (gamma), flattened to length
    /// `normalized_shape.iter().product()` (may be a 1-element stub for
    /// `placeholder()` layers until `set_weight` is called)
    weight: Tensor,
    /// Whether the learnable scale is applied in `forward`
    elementwise_affine: bool,
}
198
199impl RMSNorm {
200    /// Create a new `RMSNorm` layer.
201    ///
202    /// # Arguments
203    ///
204    /// * `normalized_shape` - Shape of the dimensions to normalize over
205    #[must_use]
206    pub fn new(normalized_shape: &[usize]) -> Self {
207        let numel: usize = normalized_shape.iter().product();
208        Self {
209            normalized_shape: normalized_shape.to_vec(),
210            eps: 1e-6, // Smaller default eps than LayerNorm (common in LLMs)
211            weight: constant(&[numel], 1.0).requires_grad(),
212            elementwise_affine: true,
213        }
214    }
215
216    /// Create `RMSNorm` with custom epsilon.
217    #[must_use]
218    pub fn with_eps(normalized_shape: &[usize], eps: f32) -> Self {
219        let mut layer = Self::new(normalized_shape);
220        layer.eps = eps;
221        layer
222    }
223
224    /// Create `RMSNorm` without learnable parameters.
225    #[must_use]
226    pub fn without_affine(normalized_shape: &[usize]) -> Self {
227        let numel: usize = normalized_shape.iter().product();
228        Self {
229            normalized_shape: normalized_shape.to_vec(),
230            eps: 1e-6,
231            weight: constant(&[numel], 1.0),
232            elementwise_affine: false,
233        }
234    }
235
236    /// Get the normalized shape.
237    #[must_use]
238    pub fn normalized_shape(&self) -> &[usize] {
239        &self.normalized_shape
240    }
241
242    /// Get the epsilon value.
243    #[must_use]
244    pub fn eps(&self) -> f32 {
245        self.eps
246    }
247
248    /// Set weight tensor from external data.
249    ///
250    /// Used for loading pre-trained weights.
251    pub fn set_weight(&mut self, weight: Tensor) {
252        self.weight = weight;
253    }
254
255    /// Get reference to weight tensor.
256    #[must_use]
257    pub fn weight(&self) -> &Tensor {
258        &self.weight
259    }
260
261    /// Create a placeholder `RMSNorm` layer with minimal memory allocation.
262    ///
263    /// Used for lazy initialization when loading pre-trained weights.
264    /// The placeholder uses 1-element tensors instead of full vectors,
265    /// reducing memory from O(n) to O(1).
266    ///
267    /// **IMPORTANT**: This layer will NOT work for inference until
268    /// `set_weight()` is called with real weights.
269    #[must_use]
270    pub fn placeholder(normalized_shape: &[usize]) -> Self {
271        Self {
272            normalized_shape: normalized_shape.to_vec(),
273            eps: 1e-6,
274            weight: Tensor::new(&[1.0], &[1]),
275            elementwise_affine: true,
276        }
277    }
278}
279
280impl Module for RMSNorm {
281    /// ONE PATH: Delegates computation to `nn::functional::rms_norm` (UCBD ยง4).
282    /// Shape validation and non-affine path handled here (Module layer).
283    #[provable_contracts_macros::contract("rmsnorm-kernel-v1", equation = "rmsnorm")]
284    fn forward(&self, input: &Tensor) -> Tensor {
285        let shape = input.shape();
286        let norm_size: usize = self.normalized_shape.iter().product();
287
288        // Check dimensions
289        assert!(
290            shape.len() >= self.normalized_shape.len(),
291            "Input must have at least as many dimensions as normalized_shape"
292        );
293
294        // Check that the last dimensions match
295        let start_dim = shape.len() - self.normalized_shape.len();
296        for (i, &ns) in self.normalized_shape.iter().enumerate() {
297            assert_eq!(
298                shape[start_dim + i],
299                ns,
300                "Input shape doesn't match normalized_shape at dim {i}"
301            );
302        }
303
304        if self.elementwise_affine {
305            // ONE PATH: delegate to canonical functional rms_norm
306            crate::nn::functional::rms_norm(input, &self.weight, self.eps)
307        } else {
308            // Non-affine: normalize without weight
309            let batch_dims: usize = shape[..start_dim].iter().product();
310            let input_data = input.data();
311            let mut output_data = vec![0.0; input_data.len()];
312
313            for b in 0..batch_dims {
314                let offset = b * norm_size;
315                let slice = &input_data[offset..offset + norm_size];
316
317                let mean_sq: f32 = slice.iter().map(|&x| x * x).sum::<f32>() / norm_size as f32;
318                let rms_inv = 1.0 / (mean_sq + self.eps).sqrt();
319
320                for i in 0..norm_size {
321                    output_data[offset + i] = slice[i] * rms_inv;
322                }
323            }
324
325            Tensor::new(&output_data, shape)
326        }
327    }
328
329    fn parameters(&self) -> Vec<&Tensor> {
330        if self.elementwise_affine {
331            vec![&self.weight]
332        } else {
333            vec![]
334        }
335    }
336
337    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
338        if self.elementwise_affine {
339            vec![&mut self.weight]
340        } else {
341            vec![]
342        }
343    }
344}
345
/// Instance Normalization.
///
/// Normalizes each channel independently for each sample.
/// Commonly used in style transfer networks.
///
/// This is equivalent to `GroupNorm` with `num_groups` = `num_channels`.
#[derive(Debug)]
pub struct InstanceNorm {
    /// Underlying `GroupNorm` configured with one group per channel.
    inner: GroupNorm,
}
356
357impl InstanceNorm {
358    /// Create a new `InstanceNorm` layer.
359    #[must_use]
360    pub fn new(num_channels: usize) -> Self {
361        Self {
362            inner: GroupNorm::new(num_channels, num_channels),
363        }
364    }
365
366    /// Create `InstanceNorm` without learnable parameters.
367    #[must_use]
368    pub fn without_affine(num_channels: usize) -> Self {
369        Self {
370            inner: GroupNorm::without_affine(num_channels, num_channels),
371        }
372    }
373}
374
impl Module for InstanceNorm {
    /// Delegates to the inner `GroupNorm` (num_groups == num_channels),
    /// which is exactly instance normalization.
    fn forward(&self, input: &Tensor) -> Tensor {
        self.inner.forward(input)
    }

    /// Delegates to the inner `GroupNorm`'s parameter list.
    fn parameters(&self) -> Vec<&Tensor> {
        self.inner.parameters()
    }

    /// Delegates to the inner `GroupNorm`'s mutable parameter list.
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        self.inner.parameters_mut()
    }
}
388
// Unit tests live in a sibling file; `#[path]` points the module there.
#[cfg(test)]
#[path = "tests.rs"]
mod tests;