//! Normalization Layers
//!
//! # File
//! `crates/axonml-nn/src/layers/norm.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at your own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::Variable;
use axonml_autograd::functions::{
    BatchNorm1dBackward, BatchNorm2dBackward, GroupNormBackward, InstanceNorm2dBackward,
    LayerNormBackward,
};
use axonml_autograd::grad_fn::GradFn;
use axonml_autograd::no_grad::is_grad_enabled;
use axonml_tensor::Tensor;
use parking_lot::RwLock;

use crate::init::{ones, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

// =============================================================================
// BatchNorm1d
// =============================================================================

/// Applies Batch Normalization over a 2D or 3D input.
///
/// y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
///
/// # Shape
/// - Input: (N, C) or (N, C, L)
/// - Output: Same as input
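///
/// # Example
///
/// A minimal usage sketch mirroring this file's unit tests; the `use` paths
/// are assumptions and depend on how the crate re-exports these types.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_tensor::Tensor;
///
/// // Two samples with three features each; the crate's `Module` trait must be
/// // in scope so that `forward` can be called on the layer.
/// let bn = BatchNorm1d::new(3);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
///     false,
/// );
/// let output = bn.forward(&input);
/// assert_eq!(output.shape(), vec![2, 3]);
/// ```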
pub struct BatchNorm1d {
    /// Learnable scale parameter (gamma).
    pub weight: Parameter,
    /// Learnable shift parameter (beta).
    pub bias: Parameter,
    /// Running mean for inference (updated during training).
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance for inference (updated during training).
    running_var: RwLock<Tensor<f32>>,
    /// Number of features.
    num_features: usize,
    /// Epsilon for numerical stability.
    eps: f32,
    /// Momentum for running stats update: running = (1 - momentum) * running + momentum * batch.
    momentum: f32,
    /// Whether to track running stats.
    track_running_stats: bool,
    /// Whether in training mode.
    training: AtomicBool,
}

impl BatchNorm1d {
    /// Creates a new BatchNorm1d layer.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1, true)
    }

    /// Creates a BatchNorm1d with custom options.
    pub fn with_options(
        num_features: usize,
        eps: f32,
        momentum: f32,
        track_running_stats: bool,
    ) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            track_running_stats,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of features.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm1d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_features = shape[1];

        // Validate input matches expected features
        assert_eq!(
            num_features, self.num_features,
            "BatchNorm1d: expected {} features, got {}",
            self.num_features, num_features
        );

        let is_training = self.training.load(Ordering::Relaxed);
        let spatial_size: usize = if shape.len() > 2 {
            shape[2..].iter().product()
        } else {
            1
        };

        // GPU fast path: use fused batchnorm kernels when the input is on GPU.
        // For a [batch, features] layout, spatial = 1, so the kernel's channel
        // index (idx / spatial) % C reduces to idx % C. That is still correct,
        // because [batch, features] has the same memory layout as [batch, features, 1].
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            // Auto-migrate weight/bias to GPU if needed
            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) = input_data.batchnorm_fused(
                &gamma_gpu,
                &beta_gpu,
                self.eps,
                num_features,
                spatial_size,
            ) {
                // Update running statistics
                if self.track_running_stats {
                    let mut running_mean = self.running_mean.write();
                    let mut running_var = self.running_var.write();
                    let running_mean_vec = running_mean.to_vec();
                    let running_var_vec = running_var.to_vec();
                    let new_mean: Vec<f32> = running_mean_vec
                        .iter()
                        .zip(means.iter())
                        .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                        .collect();
                    let new_var: Vec<f32> = running_var_vec
                        .iter()
                        .zip(vars.iter())
                        .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                        .collect();
                    *running_mean = Tensor::from_vec(new_mean, &[num_features])
                        .expect("tensor creation failed");
                    *running_var = Tensor::from_vec(new_var, &[num_features])
                        .expect("tensor creation failed");
                }

                let weight_vec = gamma_gpu.to_vec();
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; num_features];
        let mut vars = vec![0.0f32; num_features];

        if is_training {
            // Calculate batch statistics
            for c in 0..num_features {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        let diff = input_vec[idx] - means[c];
                        var_sum += diff * diff;
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            // Update running statistics if tracking is enabled
            if self.track_running_stats {
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();

                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();

                *running_mean = Tensor::from_vec(new_mean, &[num_features])
                    .expect("tensor creation failed");
                *running_var = Tensor::from_vec(new_var, &[num_features])
                    .expect("tensor creation failed");
            }
        } else {
            // Use running statistics for inference
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        // Normalize: y = (x - mean) / sqrt(var + eps) * weight + bias
        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..num_features {
                for s in 0..spatial_size {
                    let idx = b * num_features * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm1d"
    }

    fn to_device(&self, device: axonml_core::Device) {
        // Move parameters
        for param in self.parameters() {
            param.to_device(device);
        }
        // Move running statistics (non-parameter buffers)
        if self.track_running_stats {
            let mut rm = self.running_mean.write();
            if let Ok(moved) = rm.to_device(device) {
                *rm = moved;
            }
            let mut rv = self.running_var.write();
            if let Ok(moved) = rv.to_device(device) {
                *rv = moved;
            }
        }
    }
}

// =============================================================================
// BatchNorm2d
// =============================================================================

/// Applies Batch Normalization over a 4D input (images).
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: Same as input
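///
/// # Example
///
/// A minimal usage sketch based on this file's unit tests; the `use` paths
/// are assumptions and depend on how the crate re-exports these types.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_tensor::Tensor;
///
/// // A batch of 2 images with 2 channels and a 2x4 spatial extent; the crate's
/// // `Module` trait must be in scope so that `forward` can be called.
/// let bn = BatchNorm2d::new(2);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
///     false,
/// );
/// let output = bn.forward(&input);
/// assert_eq!(output.shape(), vec![2, 2, 2, 4]);
/// ```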
pub struct BatchNorm2d {
    /// Learnable scale parameter (gamma).
    pub weight: Parameter,
    /// Learnable shift parameter (beta).
    pub bias: Parameter,
    /// Running mean for inference (updated during training).
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance for inference (updated during training).
    running_var: RwLock<Tensor<f32>>,
    /// Number of features (channels).
    num_features: usize,
    /// Epsilon for numerical stability.
    eps: f32,
    /// Momentum for running stats update.
    momentum: f32,
    /// Whether in training mode.
    training: AtomicBool,
}

impl BatchNorm2d {
    /// Creates a new BatchNorm2d layer.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1)
    }

    /// Creates a BatchNorm2d with custom options.
    pub fn with_options(num_features: usize, eps: f32, momentum: f32) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of features (channels).
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        // Validate input matches expected channels
        assert_eq!(
            channels, self.num_features,
            "BatchNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let is_training = self.training.load(Ordering::Relaxed);

        // GPU fast path: use fused batchnorm kernels when input is on GPU
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            // Auto-migrate weight/bias to GPU if needed
            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) =
                input_data.batchnorm_fused(&gamma_gpu, &beta_gpu, self.eps, channels, spatial_size)
            {
                // Update running statistics
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();
                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();
                *running_mean = Tensor::from_vec(new_mean, &[channels])
                    .expect("tensor creation failed");
                *running_var = Tensor::from_vec(new_var, &[channels])
                    .expect("tensor creation failed");

                let weight_vec = gamma_gpu.to_vec();
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU path
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; channels];
        let mut vars = vec![0.0f32; channels];

        if is_training {
            let n_per_channel = (batch_size * spatial_size) as f32;
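            // Single pass per channel: accumulate the sum and the sum of squares,
            // then use Var[x] = E[x^2] - (E[x])^2 so the data is only swept once.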
            for c in 0..channels {
                let mut sum = 0.0f32;
                let mut sum_sq = 0.0f32;
                for b in 0..batch_size {
                    let base = b * channels * spatial_size + c * spatial_size;
                    for s in 0..spatial_size {
                        let val = input_vec[base + s];
                        sum += val;
                        sum_sq += val * val;
                    }
                }
                means[c] = sum / n_per_channel;
                vars[c] = sum_sq / n_per_channel - means[c] * means[c];
            }

            // Update running statistics
            let mut running_mean = self.running_mean.write();
            let mut running_var = self.running_var.write();
            let running_mean_vec = running_mean.to_vec();
            let running_var_vec = running_var.to_vec();

            let new_mean: Vec<f32> = running_mean_vec
                .iter()
                .zip(means.iter())
                .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                .collect();
            let new_var: Vec<f32> = running_var_vec
                .iter()
                .zip(vars.iter())
                .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                .collect();

            *running_mean = Tensor::from_vec(new_mean, &[channels])
                .expect("tensor creation failed");
            *running_var = Tensor::from_vec(new_var, &[channels])
                .expect("tensor creation failed");
        } else {
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        // Normalize + affine transform (optimized single-pass)
        let total = input_vec.len();
        let mut output_vec = vec![0.0f32; total];

        // Pre-compute inv_std per channel to avoid repeated sqrt
        let inv_stds: Vec<f32> = vars.iter().map(|v| 1.0 / (v + self.eps).sqrt()).collect();

        for i in 0..total {
            let c = (i / spatial_size) % channels;
            output_vec[i] = (input_vec[i] - means[c]) * inv_stds[c] * weight_vec[c] + bias_vec[c];
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm2d"
    }

    fn to_device(&self, device: axonml_core::Device) {
        for param in self.parameters() {
            param.to_device(device);
        }
        // Move running statistics (non-parameter buffers)
        let mut rm = self.running_mean.write();
        if let Ok(moved) = rm.to_device(device) {
            *rm = moved;
        }
        let mut rv = self.running_var.write();
        if let Ok(moved) = rv.to_device(device) {
            *rv = moved;
        }
    }
}

// =============================================================================
// LayerNorm
// =============================================================================

/// Applies Layer Normalization over the last D dimensions.
///
/// y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
///
/// Unlike BatchNorm, LayerNorm normalizes over features, not batch.
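///
/// # Example
///
/// A minimal usage sketch based on this file's unit tests; the `use` paths
/// are assumptions and depend on how the crate re-exports these types.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_tensor::Tensor;
///
/// // Normalize each row of a [2, 4] input over its last dimension; the crate's
/// // `Module` trait must be in scope so that `forward` can be called.
/// let ln = LayerNorm::single(4);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap(),
///     false,
/// );
/// let output = ln.forward(&input);
/// assert_eq!(output.shape(), vec![2, 4]);
/// ```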
pub struct LayerNorm {
    /// Learnable scale parameter (gamma).
    pub weight: Parameter,
    /// Learnable shift parameter (beta).
    pub bias: Parameter,
    /// Normalized shape.
    normalized_shape: Vec<usize>,
    /// Epsilon for numerical stability.
    eps: f32,
}

impl LayerNorm {
    /// Creates a new LayerNorm layer.
    pub fn new(normalized_shape: Vec<usize>) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    /// Creates a LayerNorm for a single dimension.
    pub fn single(size: usize) -> Self {
        Self::new(vec![size])
    }

    /// Creates a LayerNorm with custom epsilon.
    pub fn with_eps(normalized_shape: Vec<usize>, eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();
        Self {
            weight: Parameter::named("weight", ones(&[numel]), true),
            bias: Parameter::named("bias", zeros(&[numel]), true),
            normalized_shape,
            eps,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let norm_size: usize = self.normalized_shape.iter().product();
        let total_len = input_data.numel();
        let num_rows = total_len / norm_size;

        // GPU fast path: run LayerNorm entirely on GPU via CUDA kernel
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            // Ensure weight and bias are on GPU
            let weight_data = self.weight.data();
            let weight_gpu = if weight_data.device().is_gpu() {
                weight_data.clone()
            } else {
                weight_data
                    .to_device(input_data.device().clone())
                    .expect("LayerNorm: failed to move weight to GPU")
            };
            let bias_data = self.bias.data();
            let bias_gpu = if bias_data.device().is_gpu() {
                bias_data.clone()
            } else {
                bias_data
                    .to_device(input_data.device().clone())
                    .expect("LayerNorm: failed to move bias to GPU")
            };

            let output = input_data
                .layer_norm_cuda(&weight_gpu, &bias_gpu, norm_size, self.eps)
                .expect("CUDA LayerNorm failed");

            let requires_grad = input.requires_grad() && is_grad_enabled();
            return if requires_grad {
                let grad_fn = GradFn::new(LayerNormBackward::new(
                    input.grad_fn().cloned(),
                    self.weight.variable().grad_fn().cloned(),
                    self.bias.variable().grad_fn().cloned(),
                    input_data.clone(),
                    self.weight.data().clone(),
                    self.normalized_shape.clone(),
                    self.eps,
                ));
                Variable::from_operation(output, grad_fn, true)
            } else {
                Variable::from_tensor(output)
            };
        }

        // CPU path
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..num_rows {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_vec[start..end];

            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
            let var: f32 =
                slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;

            for i in 0..norm_size {
                let normalized = (slice[i] - mean) / (var + self.eps).sqrt();
                output_vec[start + i] = normalized * weight_vec[i] + bias_vec[i];
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LayerNormBackward::new(
                input.grad_fn().cloned(),
                self.weight.variable().grad_fn().cloned(),
                self.bias.variable().grad_fn().cloned(),
                input_data.clone(),
                self.weight.data().clone(),
                self.normalized_shape.clone(),
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LayerNorm"
    }
}

// =============================================================================
// GroupNorm
// =============================================================================

/// Applies Group Normalization over a mini-batch of inputs.
///
/// Groups channels and normalizes within each group.
/// Particularly effective for small batch sizes where BatchNorm struggles.
///
/// # Shape
/// - Input: (N, C, *) where C must be divisible by num_groups
/// - Output: Same as input
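///
/// # Example
///
/// A minimal usage sketch based on this file's unit tests; the `use` paths
/// are assumptions and depend on how the crate re-exports these types.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_tensor::Tensor;
///
/// // 4 channels split into 2 groups of 2 channels each; the crate's `Module`
/// // trait must be in scope so that `forward` can be called.
/// let gn = GroupNorm::new(2, 4);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0; 32], &[2, 4, 2, 2]).unwrap(),
///     false,
/// );
/// let output = gn.forward(&input);
/// assert_eq!(output.shape(), vec![2, 4, 2, 2]);
/// ```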
pub struct GroupNorm {
    /// Learnable scale parameter (gamma).
    pub weight: Parameter,
    /// Learnable shift parameter (beta).
    pub bias: Parameter,
    /// Number of groups to divide channels into.
    num_groups: usize,
    /// Number of channels expected in input.
    num_channels: usize,
    /// Epsilon for numerical stability.
    eps: f32,
    /// Whether to use learnable affine parameters.
    affine: bool,
}

impl GroupNorm {
    /// Creates a new GroupNorm layer.
    ///
    /// # Arguments
    /// * `num_groups` - Number of groups to divide channels into
    /// * `num_channels` - Number of channels expected in input
    pub fn new(num_groups: usize, num_channels: usize) -> Self {
        Self::with_options(num_groups, num_channels, 1e-5, true)
    }

    /// Creates a GroupNorm with custom options.
    pub fn with_options(num_groups: usize, num_channels: usize, eps: f32, affine: bool) -> Self {
        assert!(
            num_channels % num_groups == 0,
            "num_channels ({}) must be divisible by num_groups ({})",
            num_channels,
            num_groups
        );

        Self {
            weight: Parameter::named("weight", ones(&[num_channels]), affine),
            bias: Parameter::named("bias", zeros(&[num_channels]), affine),
            num_groups,
            num_channels,
            eps,
            affine,
        }
    }
}

impl Module for GroupNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        assert_eq!(
            channels, self.num_channels,
            "GroupNorm: expected {} channels, got {}",
            self.num_channels, channels
        );

        let input_vec = input_data.to_vec();
        let channels_per_group = channels / self.num_groups;
        // Copy the affine parameters out once rather than re-materializing them
        // for every channel inside the loops below.
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            for g in 0..self.num_groups {
                // Calculate mean and variance for this group
                let mut sum = 0.0f32;
                let group_size = channels_per_group * spatial_size;

                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                let mean = sum / group_size as f32;

                let mut var_sum = 0.0f32;
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let diff = input_vec[idx] - mean;
                        var_sum += diff * diff;
                    }
                }
                let var = var_sum / group_size as f32;

                // Normalize
                let std_inv = 1.0 / (var + self.eps).sqrt();
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    let weight = if self.affine {
                        weight_vec[channel_idx]
                    } else {
                        1.0
                    };
                    let bias = if self.affine {
                        bias_vec[channel_idx]
                    } else {
                        0.0
                    };

                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let normalized = (input_vec[idx] - mean) * std_inv;
                        output_vec[idx] = normalized * weight + bias;
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad && self.affine {
            let grad_fn = GradFn::new(GroupNormBackward::new(
                input.grad_fn().cloned(),
                self.weight.variable().grad_fn().cloned(),
                self.bias.variable().grad_fn().cloned(),
                input_data.clone(),
                self.weight.data().clone(),
                self.num_groups,
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "GroupNorm"
    }
}

// =============================================================================
// InstanceNorm2d
// =============================================================================

/// Applies Instance Normalization over a 4D input (images).
///
/// Each channel in each sample is normalized independently.
/// Particularly useful for style transfer and image generation.
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: Same as input
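///
/// # Example
///
/// A minimal usage sketch based on this file's unit tests; the `use` paths
/// are assumptions and depend on how the crate re-exports these types.
///
/// ```ignore
/// use axonml_autograd::Variable;
/// use axonml_tensor::Tensor;
///
/// // Each (sample, channel) plane of the [2, 2, 2, 4] input is normalized on its
/// // own; the crate's `Module` trait must be in scope so that `forward` can be called.
/// let inn = InstanceNorm2d::new(2);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
///     false,
/// );
/// let output = inn.forward(&input);
/// assert_eq!(output.shape(), vec![2, 2, 2, 4]);
/// ```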
pub struct InstanceNorm2d {
    /// Learnable scale parameter (gamma).
    pub weight: Parameter,
    /// Learnable shift parameter (beta).
    pub bias: Parameter,
    /// Number of features (channels).
    num_features: usize,
    /// Epsilon for numerical stability.
    eps: f32,
    /// Whether to use learnable affine parameters.
    affine: bool,
}

impl InstanceNorm2d {
    /// Creates a new InstanceNorm2d layer.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, false)
    }

    /// Creates an InstanceNorm2d with affine parameters.
    pub fn with_affine(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, true)
    }

    /// Creates an InstanceNorm2d with custom options.
    pub fn with_options(num_features: usize, eps: f32, affine: bool) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), affine),
            bias: Parameter::named("bias", zeros(&[num_features]), affine),
            num_features,
            eps,
            affine,
        }
    }
}

impl Module for InstanceNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();

        assert!(
            shape.len() == 4,
            "InstanceNorm2d expects 4D input (N, C, H, W)"
        );

        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "InstanceNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let input_vec = input_data.to_vec();
        // Copy the affine parameters out once rather than re-materializing them
        // for every (batch, channel) pair inside the loops below.
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();
        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            for c in 0..channels {
                // Calculate mean for this (batch, channel) pair
                let mut sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    sum += input_vec[idx];
                }
                let mean = sum / spatial_size as f32;

                // Calculate variance
                let mut var_sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let diff = input_vec[idx] - mean;
                    var_sum += diff * diff;
                }
                let var = var_sum / spatial_size as f32;

                // Normalize and apply affine
                let std_inv = 1.0 / (var + self.eps).sqrt();
                let weight = if self.affine { weight_vec[c] } else { 1.0 };
                let bias = if self.affine { bias_vec[c] } else { 0.0 };

                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - mean) * std_inv;
                    output_vec[idx] = normalized * weight + bias;
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(InstanceNorm2dBackward::new(
                input.grad_fn().cloned(),
                if self.affine {
                    self.weight.variable().grad_fn().cloned()
                } else {
                    None
                },
                if self.affine {
                    self.bias.variable().grad_fn().cloned()
                } else {
                    None
                },
                input_data.clone(),
                self.weight.data().clone(),
                self.eps,
                self.affine,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "InstanceNorm2d"
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm1d() {
        let bn = BatchNorm1d::new(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
                .expect("tensor creation failed"),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }

    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).expect("tensor creation failed"),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
                .expect("tensor creation failed"),
            false,
        );
        let output = ln.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }

    #[test]
    fn test_batchnorm_parameters() {
        let bn = BatchNorm1d::new(10);
        assert_eq!(bn.parameters().len(), 2);
        assert_eq!(bn.num_parameters(), 20); // weight + bias
    }

    #[test]
    fn test_groupnorm() {
        let gn = GroupNorm::new(2, 4); // 2 groups, 4 channels
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 4, 2, 2]).expect("tensor creation failed"),
            false,
        );
        let output = gn.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 2, 2]);
    }

    #[test]
    fn test_groupnorm_normalization() {
        let gn = GroupNorm::with_options(2, 4, 1e-5, false); // No affine
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 4, 1, 2])
                .expect("tensor creation failed"),
            false,
        );
        let output = gn.forward(&input);
        // After normalization within groups, should have zero mean
        let out_vec = output.data().to_vec();
        // Group 1: channels 0,1 (vals 1,2,3,4) and Group 2: channels 2,3 (vals 5,6,7,8)
        let group1_mean: f32 = out_vec[0..4].iter().sum::<f32>() / 4.0;
        let group2_mean: f32 = out_vec[4..8].iter().sum::<f32>() / 4.0;
        assert!(group1_mean.abs() < 1e-5);
        assert!(group2_mean.abs() < 1e-5);
    }

    #[test]
    fn test_instancenorm2d() {
        let inn = InstanceNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).expect("tensor creation failed"),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_instancenorm2d_with_affine() {
        let inn = InstanceNorm2d::with_affine(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 64], &[1, 4, 4, 4]).expect("tensor creation failed"),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 4, 4]);
        assert_eq!(inn.parameters().len(), 2);
    }
}