//! Normalization Layers
//!
//! # File
//! `crates/axonml-nn/src/layers/norm.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

17use std::collections::HashMap;
18use std::sync::atomic::{AtomicBool, Ordering};
19
20use axonml_autograd::Variable;
21use axonml_autograd::functions::{
22    BatchNorm1dBackward, BatchNorm2dBackward, GroupNormBackward, InstanceNorm2dBackward,
23    LayerNormBackward,
24};
25use axonml_autograd::grad_fn::GradFn;
26use axonml_autograd::no_grad::is_grad_enabled;
27use axonml_tensor::Tensor;
28use parking_lot::RwLock;
29
30use crate::init::{ones, zeros};
31use crate::module::Module;
32use crate::parameter::Parameter;
33
34// =============================================================================
35// BatchNorm1d
36// =============================================================================
37
/// Applies Batch Normalization over a 2D or 3D input.
///
/// y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
///
/// During training, statistics are computed per feature over the batch (and
/// any trailing spatial dimension); at inference the stored running
/// statistics are used instead. Running stats sit behind `RwLock`s so
/// `forward(&self)` can update them through a shared reference.
///
/// # Shape
/// - Input: (N, C) or (N, C, L)
/// - Output: Same as input
pub struct BatchNorm1d {
    /// Learnable scale parameter (gamma).
    pub weight: Parameter,
    /// Learnable shift parameter (beta).
    pub bias: Parameter,
    /// Running mean for inference (updated during training).
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance for inference (updated during training).
    running_var: RwLock<Tensor<f32>>,
    /// Number of features (C); validated against the input shape in `forward`.
    num_features: usize,
    /// Epsilon for numerical stability.
    eps: f32,
    /// Momentum for running stats update: running = (1 - momentum) * running + momentum * batch.
    momentum: f32,
    /// Whether to track running stats.
    track_running_stats: bool,
    /// Whether in training mode (atomic so `forward(&self)` can read it).
    training: AtomicBool,
}
65
66impl BatchNorm1d {
67    /// Creates a new BatchNorm1d layer.
68    pub fn new(num_features: usize) -> Self {
69        Self::with_options(num_features, 1e-5, 0.1, true)
70    }
71
72    /// Creates a BatchNorm1d with custom options.
73    pub fn with_options(
74        num_features: usize,
75        eps: f32,
76        momentum: f32,
77        track_running_stats: bool,
78    ) -> Self {
79        Self {
80            weight: Parameter::named("weight", ones(&[num_features]), true),
81            bias: Parameter::named("bias", zeros(&[num_features]), true),
82            running_mean: RwLock::new(zeros(&[num_features])),
83            running_var: RwLock::new(ones(&[num_features])),
84            num_features,
85            eps,
86            momentum,
87            track_running_stats,
88            training: AtomicBool::new(true),
89        }
90    }
91
92    /// Returns the number of features.
93    pub fn num_features(&self) -> usize {
94        self.num_features
95    }
96}
97
impl Module for BatchNorm1d {
    /// Normalizes per feature using batch statistics (training) or the stored
    /// running statistics (inference), then applies `y = x_hat * weight + bias`.
    ///
    /// Panics if the input's feature dimension (shape[1]) does not equal
    /// `num_features`.
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_features = shape[1];

        // Validate input matches expected features
        assert_eq!(
            num_features, self.num_features,
            "BatchNorm1d: expected {} features, got {}",
            self.num_features, num_features
        );

        let is_training = self.training.load(Ordering::Relaxed);
        // Trailing length L for (N, C, L) inputs; 1 for plain (N, C).
        let spatial_size: usize = if shape.len() > 2 {
            shape[2..].iter().product()
        } else {
            1
        };

        // GPU fast path: use fused batchnorm kernels when input is on GPU.
        // For [batch, features] layout, spatial=1. The kernel indexes as
        // (idx / spatial) % C which with spatial=1 becomes idx % C — correct
        // for [batch, features] since it's the same layout as [batch, features, 1].
        // Only taken in training mode (the fused kernel computes batch stats);
        // eval-mode GPU tensors fall through to the CPU path below.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            // Auto-migrate weight/bias to GPU if needed; if migration fails,
            // the original tensor is kept as-is (best effort).
            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) = input_data.batchnorm_fused(
                &gamma_gpu,
                &beta_gpu,
                self.eps,
                num_features,
                spatial_size,
            ) {
                // Update running statistics:
                // running = (1 - momentum) * running + momentum * batch.
                if self.track_running_stats {
                    let mut running_mean = self.running_mean.write();
                    let mut running_var = self.running_var.write();
                    let running_mean_vec = running_mean.to_vec();
                    let running_var_vec = running_var.to_vec();
                    let new_mean: Vec<f32> = running_mean_vec
                        .iter()
                        .zip(means.iter())
                        .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                        .collect();
                    let new_var: Vec<f32> = running_var_vec
                        .iter()
                        .zip(vars.iter())
                        .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                        .collect();
                    *running_mean = Tensor::from_vec(new_mean, &[num_features]).unwrap();
                    *running_var = Tensor::from_vec(new_var, &[num_features]).unwrap();
                }

                let weight_vec = gamma_gpu.to_vec();
                // NOTE(review): bias.requires_grad is not consulted here (nor in
                // the CPU path below) — weight and bias are presumed to toggle
                // their trainability together. Confirm that assumption holds.
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU path.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; num_features];
        let mut vars = vec![0.0f32; num_features];

        if is_training {
            // Calculate batch statistics: two-pass mean/variance per feature,
            // averaged over batch and spatial positions.
            for c in 0..num_features {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        let diff = input_vec[idx] - means[c];
                        var_sum += diff * diff;
                    }
                }
                // Biased variance (divide by N). NOTE(review): PyTorch updates
                // running_var with the unbiased (N-1) estimate — confirm the
                // biased update below is intentional.
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            // Update running statistics if tracking is enabled
            if self.track_running_stats {
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();

                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();

                *running_mean = Tensor::from_vec(new_mean, &[num_features]).unwrap();
                *running_var = Tensor::from_vec(new_var, &[num_features]).unwrap();
            }
        } else {
            // Use running statistics for inference
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        // Normalize: y = (x - mean) / sqrt(var + eps) * weight + bias
        // (sqrt is recomputed per element; BatchNorm2d precomputes inv_std —
        // candidate for the same treatment here.)
        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..num_features {
                for s in 0..spatial_size {
                    let idx = b * num_features * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    /// Returns the learnable parameters: `[weight, bias]`.
    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    /// Returns the learnable parameters keyed by name.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    /// Switches between batch statistics (training) and running statistics
    /// (inference) in `forward`.
    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm1d"
    }
}
311
312// =============================================================================
313// BatchNorm2d
314// =============================================================================
315
/// Applies Batch Normalization over a 4D input (images).
///
/// Running statistics are always tracked (unlike [`BatchNorm1d`], there is no
/// `track_running_stats` option) and sit behind `RwLock`s so `forward(&self)`
/// can update them during training.
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: Same as input
pub struct BatchNorm2d {
    /// Learnable scale parameter (gamma).
    pub weight: Parameter,
    /// Learnable shift parameter (beta).
    pub bias: Parameter,
    /// Running mean for inference (updated during training).
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance for inference (updated during training).
    running_var: RwLock<Tensor<f32>>,
    /// Number of features (channels); validated against input in `forward`.
    num_features: usize,
    /// Epsilon for numerical stability.
    eps: f32,
    /// Momentum for running stats update.
    momentum: f32,
    /// Whether in training mode (atomic so `forward(&self)` can read it).
    training: AtomicBool,
}
339
340impl BatchNorm2d {
341    /// Creates a new BatchNorm2d layer.
342    pub fn new(num_features: usize) -> Self {
343        Self::with_options(num_features, 1e-5, 0.1)
344    }
345
346    /// Creates a BatchNorm2d with custom options.
347    pub fn with_options(num_features: usize, eps: f32, momentum: f32) -> Self {
348        Self {
349            weight: Parameter::named("weight", ones(&[num_features]), true),
350            bias: Parameter::named("bias", zeros(&[num_features]), true),
351            running_mean: RwLock::new(zeros(&[num_features])),
352            running_var: RwLock::new(ones(&[num_features])),
353            num_features,
354            eps,
355            momentum,
356            training: AtomicBool::new(true),
357        }
358    }
359
360    /// Returns the number of features (channels).
361    pub fn num_features(&self) -> usize {
362        self.num_features
363    }
364}
365
impl Module for BatchNorm2d {
    /// Normalizes per channel over (N, H, W) using batch statistics in
    /// training mode or the stored running statistics in eval mode, then
    /// applies `y = x_hat * weight + bias`.
    ///
    /// Panics if the input is not 4D (indexing shape[2]/shape[3]) or if the
    /// channel dimension does not equal `num_features`.
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        // Validate input matches expected channels
        assert_eq!(
            channels, self.num_features,
            "BatchNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let is_training = self.training.load(Ordering::Relaxed);

        // GPU fast path: use fused batchnorm kernels when input is on GPU.
        // Only taken in training mode (the fused kernel computes batch stats);
        // eval-mode GPU tensors fall through to the CPU path below.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            // Auto-migrate weight/bias to GPU if needed; if migration fails,
            // the original tensor is kept as-is (best effort).
            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) =
                input_data.batchnorm_fused(&gamma_gpu, &beta_gpu, self.eps, channels, spatial_size)
            {
                // Update running statistics (always: BatchNorm2d has no
                // track_running_stats option):
                // running = (1 - momentum) * running + momentum * batch.
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();
                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();
                *running_mean = Tensor::from_vec(new_mean, &[channels]).unwrap();
                *running_var = Tensor::from_vec(new_var, &[channels]).unwrap();

                let weight_vec = gamma_gpu.to_vec();
                // NOTE(review): bias.requires_grad is not consulted — weight
                // and bias are presumed to toggle trainability together.
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU path
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; channels];
        let mut vars = vec![0.0f32; channels];

        if is_training {
            let n_per_channel = (batch_size * spatial_size) as f32;
            for c in 0..channels {
                let mut sum = 0.0f32;
                let mut sum_sq = 0.0f32;
                for b in 0..batch_size {
                    let base = b * channels * spatial_size + c * spatial_size;
                    for s in 0..spatial_size {
                        let val = input_vec[base + s];
                        sum += val;
                        sum_sq += val * val;
                    }
                }
                means[c] = sum / n_per_channel;
                // One-pass biased variance E[x^2] - E[x]^2.
                // NOTE(review): this form can come out slightly negative from
                // float cancellation; `eps` below keeps the sqrt finite, but
                // confirm that is acceptable (BatchNorm1d uses the two-pass
                // formula instead).
                vars[c] = sum_sq / n_per_channel - means[c] * means[c];
            }

            // Update running statistics
            let mut running_mean = self.running_mean.write();
            let mut running_var = self.running_var.write();
            let running_mean_vec = running_mean.to_vec();
            let running_var_vec = running_var.to_vec();

            let new_mean: Vec<f32> = running_mean_vec
                .iter()
                .zip(means.iter())
                .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                .collect();
            let new_var: Vec<f32> = running_var_vec
                .iter()
                .zip(vars.iter())
                .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                .collect();

            *running_mean = Tensor::from_vec(new_mean, &[channels]).unwrap();
            *running_var = Tensor::from_vec(new_var, &[channels]).unwrap();
        } else {
            // Use running statistics for inference.
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        // Normalize + affine transform (optimized single-pass)
        let total = input_vec.len();
        let mut output_vec = vec![0.0f32; total];

        // Pre-compute inv_std per channel to avoid repeated sqrt
        let inv_stds: Vec<f32> = vars.iter().map(|v| 1.0 / (v + self.eps).sqrt()).collect();

        for i in 0..total {
            // Channel index recovered from the flat (N, C, H, W) offset.
            let c = (i / spatial_size) % channels;
            output_vec[i] = (input_vec[i] - means[c]) * inv_stds[c] * weight_vec[c] + bias_vec[c];
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    /// Returns the learnable parameters: `[weight, bias]`.
    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    /// Returns the learnable parameters keyed by name.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    /// Switches between batch statistics (training) and running statistics
    /// (inference) in `forward`.
    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm2d"
    }
}
560
561// =============================================================================
562// LayerNorm
563// =============================================================================
564
/// Applies Layer Normalization over the last D dimensions.
///
/// y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
///
/// Unlike BatchNorm, LayerNorm normalizes over features, not batch, so it
/// keeps no running statistics.
pub struct LayerNorm {
    /// Learnable scale parameter (gamma), flattened to
    /// `normalized_shape.iter().product()` elements.
    pub weight: Parameter,
    /// Learnable shift parameter (beta), flattened like `weight`.
    pub bias: Parameter,
    /// Shape of the trailing dimensions that are normalized together.
    normalized_shape: Vec<usize>,
    /// Epsilon for numerical stability.
    eps: f32,
}
580
581impl LayerNorm {
582    /// Creates a new LayerNorm layer.
583    pub fn new(normalized_shape: Vec<usize>) -> Self {
584        Self::with_eps(normalized_shape, 1e-5)
585    }
586
587    /// Creates a LayerNorm for a single dimension.
588    pub fn single(size: usize) -> Self {
589        Self::new(vec![size])
590    }
591
592    /// Creates a LayerNorm with custom epsilon.
593    pub fn with_eps(normalized_shape: Vec<usize>, eps: f32) -> Self {
594        let numel: usize = normalized_shape.iter().product();
595        Self {
596            weight: Parameter::named("weight", ones(&[numel]), true),
597            bias: Parameter::named("bias", zeros(&[numel]), true),
598            normalized_shape,
599            eps,
600        }
601    }
602}
603
impl Module for LayerNorm {
    /// Normalizes each contiguous row of `norm_size` elements (the product of
    /// `normalized_shape`) to zero mean / unit variance, then applies the
    /// elementwise affine transform `y = x_hat * weight + bias`.
    ///
    /// Assumes the trailing dimensions of `input` multiply to `norm_size`;
    /// this is not validated here — TODO confirm callers guarantee it.
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let norm_size: usize = self.normalized_shape.iter().product();
        let total_len = input_data.numel();
        // All leading dimensions fold into rows of `norm_size` elements.
        let num_rows = total_len / norm_size;

        // GPU fast path: run LayerNorm entirely on GPU via CUDA kernel
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            // Ensure weight and bias are on GPU
            let weight_data = self.weight.data();
            let weight_gpu = if weight_data.device().is_gpu() {
                weight_data.clone()
            } else {
                weight_data.to_device(input_data.device().clone()).unwrap()
            };
            let bias_data = self.bias.data();
            let bias_gpu = if bias_data.device().is_gpu() {
                bias_data.clone()
            } else {
                bias_data.to_device(input_data.device().clone()).unwrap()
            };

            // NOTE(review): unlike the BatchNorm GPU paths, a kernel failure
            // panics here rather than falling back to the CPU path — confirm
            // that is intended.
            let output = input_data
                .layer_norm_cuda(&weight_gpu, &bias_gpu, norm_size, self.eps)
                .expect("CUDA LayerNorm failed");

            // NOTE(review): weight.requires_grad is not consulted, unlike the
            // BatchNorm layers; gradients only flow when the input requires
            // them. Confirm this asymmetry is intentional.
            let requires_grad = input.requires_grad() && is_grad_enabled();
            return if requires_grad {
                let grad_fn = GradFn::new(LayerNormBackward::new(
                    input.grad_fn().cloned(),
                    self.weight.variable().grad_fn().cloned(),
                    self.bias.variable().grad_fn().cloned(),
                    input_data.clone(),
                    self.weight.data().clone(),
                    self.normalized_shape.clone(),
                    self.eps,
                ));
                Variable::from_operation(output, grad_fn, true)
            } else {
                Variable::from_tensor(output)
            };
        }

        // CPU path
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..num_rows {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_vec[start..end];

            // Per-row mean and biased variance (divide by norm_size).
            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
            let var: f32 = slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;

            for i in 0..norm_size {
                let normalized = (slice[i] - mean) / (var + self.eps).sqrt();
                output_vec[start + i] = normalized * weight_vec[i] + bias_vec[i];
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LayerNormBackward::new(
                input.grad_fn().cloned(),
                self.weight.variable().grad_fn().cloned(),
                self.bias.variable().grad_fn().cloned(),
                input_data.clone(),
                self.weight.data().clone(),
                self.normalized_shape.clone(),
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    /// Returns the learnable parameters: `[weight, bias]`.
    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    /// Returns the learnable parameters keyed by name.
    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LayerNorm"
    }
}
705
706// =============================================================================
707// GroupNorm
708// =============================================================================
709
/// Applies Group Normalization over a mini-batch of inputs.
///
/// Groups channels and normalizes within each group.
/// Particularly effective for small batch sizes where BatchNorm struggles.
///
/// # Shape
/// - Input: (N, C, *) where C must be divisible by num_groups
/// - Output: Same as input
pub struct GroupNorm {
    /// Learnable scale parameter (gamma), one value per channel.
    pub weight: Parameter,
    /// Learnable shift parameter (beta), one value per channel.
    pub bias: Parameter,
    /// Number of groups to divide channels into.
    num_groups: usize,
    /// Number of channels expected in input (validated in `forward`).
    num_channels: usize,
    /// Epsilon for numerical stability.
    eps: f32,
    /// Whether to use learnable affine parameters; `affine` is also passed as
    /// the trainability flag when constructing `weight`/`bias`.
    affine: bool,
}
732
733impl GroupNorm {
734    /// Creates a new GroupNorm layer.
735    ///
736    /// # Arguments
737    /// * `num_groups` - Number of groups to divide channels into
738    /// * `num_channels` - Number of channels expected in input
739    pub fn new(num_groups: usize, num_channels: usize) -> Self {
740        Self::with_options(num_groups, num_channels, 1e-5, true)
741    }
742
743    /// Creates a GroupNorm with custom options.
744    pub fn with_options(num_groups: usize, num_channels: usize, eps: f32, affine: bool) -> Self {
745        assert!(
746            num_channels % num_groups == 0,
747            "num_channels ({}) must be divisible by num_groups ({})",
748            num_channels,
749            num_groups
750        );
751
752        Self {
753            weight: Parameter::named("weight", ones(&[num_channels]), affine),
754            bias: Parameter::named("bias", zeros(&[num_channels]), affine),
755            num_groups,
756            num_channels,
757            eps,
758            affine,
759        }
760    }
761}
762
763impl Module for GroupNorm {
764    fn forward(&self, input: &Variable) -> Variable {
765        let input_data = input.data();
766        let shape = input_data.shape().to_vec();
767        let batch_size = shape[0];
768        let channels = shape[1];
769        let spatial_size: usize = shape[2..].iter().product();
770
771        assert_eq!(
772            channels, self.num_channels,
773            "GroupNorm: expected {} channels, got {}",
774            self.num_channels, channels
775        );
776
777        let input_vec = input_data.to_vec();
778        let channels_per_group = channels / self.num_groups;
779
780        let mut output_vec = vec![0.0f32; input_vec.len()];
781
782        for b in 0..batch_size {
783            for g in 0..self.num_groups {
784                // Calculate mean and variance for this group
785                let mut sum = 0.0f32;
786                let group_size = channels_per_group * spatial_size;
787
788                for c in 0..channels_per_group {
789                    let channel_idx = g * channels_per_group + c;
790                    for s in 0..spatial_size {
791                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
792                        sum += input_vec[idx];
793                    }
794                }
795                let mean = sum / group_size as f32;
796
797                let mut var_sum = 0.0f32;
798                for c in 0..channels_per_group {
799                    let channel_idx = g * channels_per_group + c;
800                    for s in 0..spatial_size {
801                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
802                        let diff = input_vec[idx] - mean;
803                        var_sum += diff * diff;
804                    }
805                }
806                let var = var_sum / group_size as f32;
807
808                // Normalize
809                let std_inv = 1.0 / (var + self.eps).sqrt();
810                for c in 0..channels_per_group {
811                    let channel_idx = g * channels_per_group + c;
812                    let weight = if self.affine {
813                        self.weight.data().to_vec()[channel_idx]
814                    } else {
815                        1.0
816                    };
817                    let bias = if self.affine {
818                        self.bias.data().to_vec()[channel_idx]
819                    } else {
820                        0.0
821                    };
822
823                    for s in 0..spatial_size {
824                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
825                        let normalized = (input_vec[idx] - mean) * std_inv;
826                        output_vec[idx] = normalized * weight + bias;
827                    }
828                }
829            }
830        }
831
832        let output = Tensor::from_vec(output_vec, &shape).unwrap();
833        let requires_grad = input.requires_grad() && is_grad_enabled();
834
835        if requires_grad && self.affine {
836            let grad_fn = GradFn::new(GroupNormBackward::new(
837                input.grad_fn().cloned(),
838                self.weight.variable().grad_fn().cloned(),
839                self.bias.variable().grad_fn().cloned(),
840                input_data.clone(),
841                self.weight.data().clone(),
842                self.num_groups,
843                self.eps,
844            ));
845            Variable::from_operation(output, grad_fn, true)
846        } else {
847            Variable::from_tensor(output)
848        }
849    }
850
851    fn parameters(&self) -> Vec<Parameter> {
852        if self.affine {
853            vec![self.weight.clone(), self.bias.clone()]
854        } else {
855            vec![]
856        }
857    }
858
859    fn named_parameters(&self) -> HashMap<String, Parameter> {
860        if self.affine {
861            let mut params = HashMap::new();
862            params.insert("weight".to_string(), self.weight.clone());
863            params.insert("bias".to_string(), self.bias.clone());
864            params
865        } else {
866            HashMap::new()
867        }
868    }
869
    /// Layer name used for display/debugging.
    fn name(&self) -> &'static str {
        "GroupNorm"
    }
873}
874
875// =============================================================================
876// InstanceNorm2d
877// =============================================================================
878
/// Applies Instance Normalization over a 4D input (images).
///
/// Each channel in each sample is normalized independently.
/// Particularly useful for style transfer and image generation.
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: Same as input
pub struct InstanceNorm2d {
    /// Learnable scale parameter (gamma). Always allocated (length
    /// `num_features`), but only applied in `forward` when `affine` is true.
    pub weight: Parameter,
    /// Learnable shift parameter (beta). Always allocated (length
    /// `num_features`), but only applied in `forward` when `affine` is true.
    pub bias: Parameter,
    /// Number of features (channels); the input's C dimension must match.
    num_features: usize,
    /// Epsilon for numerical stability.
    eps: f32,
    /// Whether to use learnable affine parameters.
    affine: bool,
}
899
900impl InstanceNorm2d {
901    /// Creates a new InstanceNorm2d layer.
902    pub fn new(num_features: usize) -> Self {
903        Self::with_options(num_features, 1e-5, false)
904    }
905
906    /// Creates an InstanceNorm2d with affine parameters.
907    pub fn with_affine(num_features: usize) -> Self {
908        Self::with_options(num_features, 1e-5, true)
909    }
910
911    /// Creates an InstanceNorm2d with custom options.
912    pub fn with_options(num_features: usize, eps: f32, affine: bool) -> Self {
913        Self {
914            weight: Parameter::named("weight", ones(&[num_features]), affine),
915            bias: Parameter::named("bias", zeros(&[num_features]), affine),
916            num_features,
917            eps,
918            affine,
919        }
920    }
921}
922
923impl Module for InstanceNorm2d {
924    fn forward(&self, input: &Variable) -> Variable {
925        let input_data = input.data();
926        let shape = input_data.shape().to_vec();
927
928        assert!(
929            shape.len() == 4,
930            "InstanceNorm2d expects 4D input (N, C, H, W)"
931        );
932
933        let batch_size = shape[0];
934        let channels = shape[1];
935        let height = shape[2];
936        let width = shape[3];
937        let spatial_size = height * width;
938
939        assert_eq!(
940            channels, self.num_features,
941            "InstanceNorm2d: expected {} channels, got {}",
942            self.num_features, channels
943        );
944
945        let input_vec = input_data.to_vec();
946        let mut output_vec = vec![0.0f32; input_vec.len()];
947
948        for b in 0..batch_size {
949            for c in 0..channels {
950                // Calculate mean for this (batch, channel) pair
951                let mut sum = 0.0f32;
952                for s in 0..spatial_size {
953                    let idx = b * channels * spatial_size + c * spatial_size + s;
954                    sum += input_vec[idx];
955                }
956                let mean = sum / spatial_size as f32;
957
958                // Calculate variance
959                let mut var_sum = 0.0f32;
960                for s in 0..spatial_size {
961                    let idx = b * channels * spatial_size + c * spatial_size + s;
962                    let diff = input_vec[idx] - mean;
963                    var_sum += diff * diff;
964                }
965                let var = var_sum / spatial_size as f32;
966
967                // Normalize and apply affine
968                let std_inv = 1.0 / (var + self.eps).sqrt();
969                let weight = if self.affine {
970                    self.weight.data().to_vec()[c]
971                } else {
972                    1.0
973                };
974                let bias = if self.affine {
975                    self.bias.data().to_vec()[c]
976                } else {
977                    0.0
978                };
979
980                for s in 0..spatial_size {
981                    let idx = b * channels * spatial_size + c * spatial_size + s;
982                    let normalized = (input_vec[idx] - mean) * std_inv;
983                    output_vec[idx] = normalized * weight + bias;
984                }
985            }
986        }
987
988        let output = Tensor::from_vec(output_vec, &shape).unwrap();
989        let requires_grad = input.requires_grad() && is_grad_enabled();
990
991        if requires_grad {
992            let grad_fn = GradFn::new(InstanceNorm2dBackward::new(
993                input.grad_fn().cloned(),
994                if self.affine {
995                    self.weight.variable().grad_fn().cloned()
996                } else {
997                    None
998                },
999                if self.affine {
1000                    self.bias.variable().grad_fn().cloned()
1001                } else {
1002                    None
1003                },
1004                input_data.clone(),
1005                self.weight.data().clone(),
1006                self.eps,
1007                self.affine,
1008            ));
1009            Variable::from_operation(output, grad_fn, true)
1010        } else {
1011            Variable::from_tensor(output)
1012        }
1013    }
1014
1015    fn parameters(&self) -> Vec<Parameter> {
1016        if self.affine {
1017            vec![self.weight.clone(), self.bias.clone()]
1018        } else {
1019            vec![]
1020        }
1021    }
1022
1023    fn named_parameters(&self) -> HashMap<String, Parameter> {
1024        if self.affine {
1025            let mut params = HashMap::new();
1026            params.insert("weight".to_string(), self.weight.clone());
1027            params.insert("bias".to_string(), self.bias.clone());
1028            params
1029        } else {
1030            HashMap::new()
1031        }
1032    }
1033
1034    fn name(&self) -> &'static str {
1035        "InstanceNorm2d"
1036    }
1037}
1038
1039// =============================================================================
1040// Tests
1041// =============================================================================
1042
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm1d() {
        let layer = BatchNorm1d::new(3);
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        assert_eq!(out.shape(), vec![2, 3]);
    }

    #[test]
    fn test_batchnorm2d() {
        let layer = BatchNorm2d::new(2);
        let data = Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        assert_eq!(out.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_layernorm() {
        let layer = LayerNorm::single(4);
        let data =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        assert_eq!(out.shape(), vec![2, 4]);
    }

    #[test]
    fn test_batchnorm_parameters() {
        let layer = BatchNorm1d::new(10);
        assert_eq!(layer.parameters().len(), 2);
        // 10 weights + 10 biases.
        assert_eq!(layer.num_parameters(), 20);
    }

    #[test]
    fn test_groupnorm() {
        // 2 groups over 4 channels.
        let layer = GroupNorm::new(2, 4);
        let data = Tensor::from_vec(vec![1.0; 32], &[2, 4, 2, 2]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        assert_eq!(out.shape(), vec![2, 4, 2, 2]);
    }

    #[test]
    fn test_groupnorm_normalization() {
        // Disable affine so the output is the raw normalized values.
        let layer = GroupNorm::with_options(2, 4, 1e-5, false);
        let data =
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 4, 1, 2]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        let values = out.data().to_vec();
        // Group 0 covers channels 0-1 (values 1..=4); group 1 covers
        // channels 2-3 (values 5..=8). Each group should be zero-mean
        // after normalization.
        let mean_a: f32 = values[0..4].iter().sum::<f32>() / 4.0;
        let mean_b: f32 = values[4..8].iter().sum::<f32>() / 4.0;
        assert!(mean_a.abs() < 1e-5);
        assert!(mean_b.abs() < 1e-5);
    }

    #[test]
    fn test_instancenorm2d() {
        let layer = InstanceNorm2d::new(2);
        let data = Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        assert_eq!(out.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_instancenorm2d_with_affine() {
        let layer = InstanceNorm2d::with_affine(4);
        let data = Tensor::from_vec(vec![1.0; 64], &[1, 4, 4, 4]).unwrap();
        let out = layer.forward(&Variable::new(data, false));
        assert_eq!(out.shape(), vec![1, 4, 4, 4]);
        assert_eq!(layer.parameters().len(), 2);
    }
}