axonml_nn/layers/norm.rs

//! Normalization Layers - BatchNorm and LayerNorm
//!
//! Normalizes inputs to improve training stability and speed.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use parking_lot::RwLock;

use crate::init::{ones, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// BatchNorm1d
// =============================================================================

/// Applies Batch Normalization over a 2D or 3D input.
///
/// y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
///
/// # Shape
/// - Input: (N, C) or (N, C, L)
/// - Output: Same as input
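///
/// # Example
///
/// A minimal usage sketch; the import paths below are assumptions about how
/// the crate re-exports these items, so the block is marked `ignore` and is
/// not compiled as a doctest.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_nn::layers::norm::BatchNorm1d; // assumed path
/// use axonml_nn::module::Module;            // assumed path, brings `forward` into scope
/// use axonml_tensor::Tensor;
///
/// // Two samples with three features each.
/// let bn = BatchNorm1d::new(3);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
///     false,
/// );
/// let output = bn.forward(&input);
/// assert_eq!(output.shape(), vec![2, 3]);
/// ```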
pub struct BatchNorm1d {
    /// Learnable scale parameter (gamma).
    pub weight: Parameter,
    /// Learnable shift parameter (beta).
    pub bias: Parameter,
    /// Running mean for inference (updated during training).
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance for inference (updated during training).
    running_var: RwLock<Tensor<f32>>,
    /// Number of features.
    num_features: usize,
    /// Epsilon for numerical stability.
    eps: f32,
    /// Momentum for running stats update: running = (1 - momentum) * running + momentum * batch.
    momentum: f32,
    /// Whether to track running stats.
    track_running_stats: bool,
    /// Whether in training mode.
    training: AtomicBool,
}

impl BatchNorm1d {
    /// Creates a new BatchNorm1d layer.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1, true)
    }

    /// Creates a BatchNorm1d with custom options.
    pub fn with_options(
        num_features: usize,
        eps: f32,
        momentum: f32,
        track_running_stats: bool,
    ) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            track_running_stats,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of features.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm1d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_features = shape[1];

        // Validate input matches expected features
        assert_eq!(
            num_features, self.num_features,
            "BatchNorm1d: expected {} features, got {}",
            self.num_features, num_features
        );

        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let is_training = self.training.load(Ordering::Relaxed);
        let spatial_size: usize = if shape.len() > 2 {
            shape[2..].iter().product()
        } else {
            1
        };

        let mut means = vec![0.0f32; num_features];
        let mut vars = vec![0.0f32; num_features];

        if is_training {
            // Calculate batch statistics
            for c in 0..num_features {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        let diff = input_vec[idx] - means[c];
                        var_sum += diff * diff;
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            // Update running statistics if tracking is enabled
            if self.track_running_stats {
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();

                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();

                *running_mean = Tensor::from_vec(new_mean, &[num_features]).unwrap();
                *running_var = Tensor::from_vec(new_var, &[num_features]).unwrap();
            }
        } else {
            // Use running statistics for inference
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        // Normalize: y = (x - mean) / sqrt(var + eps) * weight + bias
        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..num_features {
                for s in 0..spatial_size {
                    let idx = b * num_features * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm1d"
    }
}

// =============================================================================
// BatchNorm2d
// =============================================================================

/// Applies Batch Normalization over a 4D input (images).
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: Same as input
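///
/// # Example
///
/// A minimal usage sketch under the same path assumptions as the
/// `BatchNorm1d` example (marked `ignore`, not compiled as a doctest).
///
/// ```ignore
/// // A batch of 2 images with 2 channels and a 2 x 4 spatial extent.
/// let bn = BatchNorm2d::new(2);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
///     false,
/// );
/// let output = bn.forward(&input);
/// assert_eq!(output.shape(), vec![2, 2, 2, 4]);
/// ```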
pub struct BatchNorm2d {
    /// Learnable scale parameter (gamma).
    pub weight: Parameter,
    /// Learnable shift parameter (beta).
    pub bias: Parameter,
    /// Running mean for inference (updated during training).
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance for inference (updated during training).
    running_var: RwLock<Tensor<f32>>,
    /// Number of features (channels).
    num_features: usize,
    /// Epsilon for numerical stability.
    eps: f32,
    /// Momentum for running stats update.
    momentum: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl BatchNorm2d {
    /// Creates a new BatchNorm2d layer.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1)
    }

    /// Creates a BatchNorm2d with custom options.
    pub fn with_options(num_features: usize, eps: f32, momentum: f32) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of features (channels).
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        // Validate input matches expected channels
        assert_eq!(
            channels, self.num_features,
            "BatchNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let is_training = self.training.load(Ordering::Relaxed);

        let mut means = vec![0.0f32; channels];
        let mut vars = vec![0.0f32; channels];

        if is_training {
            for c in 0..channels {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx =
                                b * channels * spatial_size + c * spatial_size + h * width + w;
                            sum += input_vec[idx];
                        }
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx =
                                b * channels * spatial_size + c * spatial_size + h * width + w;
                            let diff = input_vec[idx] - means[c];
                            var_sum += diff * diff;
                        }
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            // Update running statistics
            let mut running_mean = self.running_mean.write();
            let mut running_var = self.running_var.write();
            let running_mean_vec = running_mean.to_vec();
            let running_var_vec = running_var.to_vec();

            let new_mean: Vec<f32> = running_mean_vec
                .iter()
                .zip(means.iter())
                .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                .collect();
            let new_var: Vec<f32> = running_var_vec
                .iter()
                .zip(vars.iter())
                .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                .collect();

            *running_mean = Tensor::from_vec(new_mean, &[channels]).unwrap();
            *running_var = Tensor::from_vec(new_var, &[channels]).unwrap();
        } else {
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..channels {
                for h in 0..height {
                    for w in 0..width {
                        let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                        let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                        output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm2d"
    }
}

// =============================================================================
// LayerNorm
// =============================================================================

/// Applies Layer Normalization over the last D dimensions.
///
/// y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
///
/// Unlike BatchNorm, LayerNorm normalizes over features, not batch.
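///
/// # Example
///
/// A minimal usage sketch (same path assumptions as above, marked `ignore`):
/// each row is normalized to roughly zero mean and unit variance before the
/// learnable affine transform is applied.
///
/// ```ignore
/// let ln = LayerNorm::single(4);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
///     false,
/// );
/// let output = ln.forward(&input);
/// assert_eq!(output.shape(), vec![1, 4]);
/// ```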
pub struct LayerNorm {
    /// Learnable scale parameter (gamma).
    pub weight: Parameter,
    /// Learnable shift parameter (beta).
    pub bias: Parameter,
    /// Normalized shape.
    normalized_shape: Vec<usize>,
    /// Epsilon for numerical stability.
    eps: f32,
}

impl LayerNorm {
    /// Creates a new LayerNorm layer.
    pub fn new(normalized_shape: Vec<usize>) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    /// Creates a LayerNorm for a single dimension.
    pub fn single(size: usize) -> Self {
        Self::new(vec![size])
    }

    /// Creates a LayerNorm with custom epsilon.
    pub fn with_eps(normalized_shape: Vec<usize>, eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();
        Self {
            weight: Parameter::named("weight", ones(&[numel]), true),
            bias: Parameter::named("bias", zeros(&[numel]), true),
            normalized_shape,
            eps,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let input_vec = input_data.to_vec();

        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        // Calculate the size of the normalized dimensions
        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = input_vec.len() / norm_size;

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_vec[start..end];

            // Calculate mean
            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;

            // Calculate variance
            let var: f32 = slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;

            // Normalize and apply affine transform
            for i in 0..norm_size {
                let normalized = (slice[i] - mean) / (var + self.eps).sqrt();
                output_vec[start + i] = normalized * weight_vec[i] + bias_vec[i];
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LayerNorm"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm1d() {
        let bn = BatchNorm1d::new(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }
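
    // A sketch of the running-mean momentum update,
    // running = (1 - momentum) * running + momentum * batch_mean,
    // using the test module's access to the private `running_mean` field.
    #[test]
    fn test_batchnorm1d_running_mean_update() {
        let bn = BatchNorm1d::new(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
            false,
        );
        let _ = bn.forward(&input);

        // Batch means are [2.5, 3.5, 4.5]; with momentum 0.1 and zero-initialized
        // running means, the update yields [0.25, 0.35, 0.45].
        let running_mean = bn.running_mean.read().to_vec();
        for (rm, expected) in running_mean.iter().zip([0.25f32, 0.35, 0.45]) {
            assert!((rm - expected).abs() < 1e-6);
        }
    }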

    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }
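
    // A sketch of eval-mode behavior: with freshly initialized running stats
    // (mean 0, var 1), inference leaves an all-ones input almost unchanged.
    #[test]
    fn test_batchnorm2d_eval_uses_running_stats() {
        let mut bn = BatchNorm2d::new(2);
        bn.set_training(false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        for value in output.data().to_vec() {
            assert!((value - 1.0).abs() < 1e-4);
        }
    }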

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap(),
            false,
        );
        let output = ln.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }
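
    // A sketch checking that, with the default gamma = 1 and beta = 0,
    // LayerNorm output has roughly zero mean and unit variance per row
    // (variance sits slightly below 1.0 because of the eps term).
    #[test]
    fn test_layernorm_statistics() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 4]).unwrap(),
            false,
        );
        let output = ln.forward(&input);
        let out = output.data().to_vec();

        let mean: f32 = out.iter().sum::<f32>() / 4.0;
        let var: f32 = out.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / 4.0;
        assert!(mean.abs() < 1e-5);
        assert!((var - 1.0).abs() < 1e-3);
    }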

    #[test]
    fn test_batchnorm_parameters() {
        let bn = BatchNorm1d::new(10);
        assert_eq!(bn.parameters().len(), 2);
        assert_eq!(bn.num_parameters(), 20); // weight + bias
    }
}