axonml_nn/layers/norm.rs
1//! Normalization Layers
2//!
3//! Provides normalization layers to improve training stability and convergence speed.
4//!
5//! # Available Layers
6//!
7//! - **BatchNorm1d/2d** - Normalizes over batch dimension (good for large batches)
8//! - **LayerNorm** - Normalizes over feature dimension (stable for any batch size)
9//! - **GroupNorm** - Normalizes within channel groups (good for small batches, used in diffusion models)
10//! - **InstanceNorm2d** - Normalizes each (batch, channel) independently (good for style transfer)
11//!
12//! # When to Use Which
13//!
14//! | Layer | Best For | Batch Size Dependency |
15//! |-------|----------|----------------------|
16//! | BatchNorm | CNNs, large batch training | Requires large batches |
17//! | LayerNorm | Transformers, RNNs | Batch-independent |
18//! | GroupNorm | ResNeXt, diffusion models | Batch-independent |
19//! | InstanceNorm | Style transfer, GANs | Batch-independent |
20//!
21//! # Example
22//!
23//! ```ignore
24//! use axonml_nn::{GroupNorm, InstanceNorm2d, Module};
25//!
26//! // GroupNorm: 8 groups, 32 channels
27//! let gn = GroupNorm::new(8, 32);
28//! let output = gn.forward(&input); // [N, 32, H, W]
29//!
30//! // InstanceNorm: normalize each channel independently
31//! let inn = InstanceNorm2d::with_affine(64);
32//! let output = inn.forward(&input); // [N, 64, H, W]
33//! ```
34//!
35//! @version 0.2.6
36//! @author AutomataNexus Development Team
37
38use std::collections::HashMap;
39use std::sync::atomic::{AtomicBool, Ordering};
40
41use axonml_autograd::Variable;
42use axonml_tensor::Tensor;
43use parking_lot::RwLock;
44
45use crate::init::{ones, zeros};
46use crate::module::Module;
47use crate::parameter::Parameter;
48
49// =============================================================================
50// BatchNorm1d
51// =============================================================================
52
/// Applies Batch Normalization over a 2D or 3D input.
///
/// y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
///
/// Statistics are computed per feature over the batch (and any trailing
/// length dimension); exponential-moving-average running estimates are
/// kept for use at inference time.
///
/// # Shape
/// - Input: (N, C) or (N, C, L)
/// - Output: Same as input
pub struct BatchNorm1d {
    /// Learnable scale parameter (gamma), one value per feature.
    pub weight: Parameter,
    /// Learnable shift parameter (beta), one value per feature.
    pub bias: Parameter,
    /// Running mean for inference (updated during training).
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance for inference (updated during training).
    running_var: RwLock<Tensor<f32>>,
    /// Number of features (the C dimension).
    num_features: usize,
    /// Epsilon added to the variance for numerical stability.
    eps: f32,
    /// Momentum for running stats update: running = (1 - momentum) * running + momentum * batch.
    momentum: f32,
    /// Whether to track running stats (when false, batch stats are the only source).
    track_running_stats: bool,
    /// Whether in training mode (atomic so `forward(&self)` can read it without `&mut`).
    training: AtomicBool,
}
80
81impl BatchNorm1d {
82    /// Creates a new BatchNorm1d layer.
83    pub fn new(num_features: usize) -> Self {
84        Self::with_options(num_features, 1e-5, 0.1, true)
85    }
86
87    /// Creates a BatchNorm1d with custom options.
88    pub fn with_options(
89        num_features: usize,
90        eps: f32,
91        momentum: f32,
92        track_running_stats: bool,
93    ) -> Self {
94        Self {
95            weight: Parameter::named("weight", ones(&[num_features]), true),
96            bias: Parameter::named("bias", zeros(&[num_features]), true),
97            running_mean: RwLock::new(zeros(&[num_features])),
98            running_var: RwLock::new(ones(&[num_features])),
99            num_features,
100            eps,
101            momentum,
102            track_running_stats,
103            training: AtomicBool::new(true),
104        }
105    }
106
107    /// Returns the number of features.
108    pub fn num_features(&self) -> usize {
109        self.num_features
110    }
111}
112
113impl Module for BatchNorm1d {
114    fn forward(&self, input: &Variable) -> Variable {
115        let input_data = input.data();
116        let shape = input_data.shape().to_vec();
117        let batch_size = shape[0];
118        let num_features = shape[1];
119
120        // Validate input matches expected features
121        assert_eq!(
122            num_features, self.num_features,
123            "BatchNorm1d: expected {} features, got {}",
124            self.num_features, num_features
125        );
126
127        let input_vec = input_data.to_vec();
128        let weight_vec = self.weight.data().to_vec();
129        let bias_vec = self.bias.data().to_vec();
130
131        let is_training = self.training.load(Ordering::Relaxed);
132        let spatial_size: usize = if shape.len() > 2 {
133            shape[2..].iter().product()
134        } else {
135            1
136        };
137
138        let mut means = vec![0.0f32; num_features];
139        let mut vars = vec![0.0f32; num_features];
140
141        if is_training {
142            // Calculate batch statistics
143            for c in 0..num_features {
144                let mut sum = 0.0f32;
145                for b in 0..batch_size {
146                    for s in 0..spatial_size {
147                        let idx = b * num_features * spatial_size + c * spatial_size + s;
148                        sum += input_vec[idx];
149                    }
150                }
151                means[c] = sum / (batch_size * spatial_size) as f32;
152
153                let mut var_sum = 0.0f32;
154                for b in 0..batch_size {
155                    for s in 0..spatial_size {
156                        let idx = b * num_features * spatial_size + c * spatial_size + s;
157                        let diff = input_vec[idx] - means[c];
158                        var_sum += diff * diff;
159                    }
160                }
161                vars[c] = var_sum / (batch_size * spatial_size) as f32;
162            }
163
164            // Update running statistics if tracking is enabled
165            if self.track_running_stats {
166                let mut running_mean = self.running_mean.write();
167                let mut running_var = self.running_var.write();
168                let running_mean_vec = running_mean.to_vec();
169                let running_var_vec = running_var.to_vec();
170
171                let new_mean: Vec<f32> = running_mean_vec
172                    .iter()
173                    .zip(means.iter())
174                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
175                    .collect();
176                let new_var: Vec<f32> = running_var_vec
177                    .iter()
178                    .zip(vars.iter())
179                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
180                    .collect();
181
182                *running_mean = Tensor::from_vec(new_mean, &[num_features]).unwrap();
183                *running_var = Tensor::from_vec(new_var, &[num_features]).unwrap();
184            }
185        } else {
186            // Use running statistics for inference
187            means = self.running_mean.read().to_vec();
188            vars = self.running_var.read().to_vec();
189        }
190
191        // Normalize: y = (x - mean) / sqrt(var + eps) * weight + bias
192        let mut output_vec = vec![0.0f32; input_vec.len()];
193        for b in 0..batch_size {
194            for c in 0..num_features {
195                for s in 0..spatial_size {
196                    let idx = b * num_features * spatial_size + c * spatial_size + s;
197                    let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
198                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
199                }
200            }
201        }
202
203        let output = Tensor::from_vec(output_vec, &shape).unwrap();
204        Variable::new(output, input.requires_grad())
205    }
206
207    fn parameters(&self) -> Vec<Parameter> {
208        vec![self.weight.clone(), self.bias.clone()]
209    }
210
211    fn named_parameters(&self) -> HashMap<String, Parameter> {
212        let mut params = HashMap::new();
213        params.insert("weight".to_string(), self.weight.clone());
214        params.insert("bias".to_string(), self.bias.clone());
215        params
216    }
217
218    fn set_training(&mut self, training: bool) {
219        self.training.store(training, Ordering::Relaxed);
220    }
221
222    fn is_training(&self) -> bool {
223        self.training.load(Ordering::Relaxed)
224    }
225
226    fn name(&self) -> &'static str {
227        "BatchNorm1d"
228    }
229}
230
231// =============================================================================
232// BatchNorm2d
233// =============================================================================
234
/// Applies Batch Normalization over a 4D input (images).
///
/// Statistics are computed per channel over the batch and spatial
/// dimensions. Unlike [`BatchNorm1d`], running statistics are always
/// tracked (there is no `track_running_stats` option here).
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: Same as input
pub struct BatchNorm2d {
    /// Learnable scale parameter (gamma), one value per channel.
    pub weight: Parameter,
    /// Learnable shift parameter (beta), one value per channel.
    pub bias: Parameter,
    /// Running mean for inference (updated during training).
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance for inference (updated during training).
    running_var: RwLock<Tensor<f32>>,
    /// Number of features (channels, the C dimension).
    num_features: usize,
    /// Epsilon added to the variance for numerical stability.
    eps: f32,
    /// Momentum for running stats update: running = (1 - momentum) * running + momentum * batch.
    momentum: f32,
    /// Whether in training mode (atomic so `forward(&self)` can read it without `&mut`).
    training: AtomicBool,
}
258
259impl BatchNorm2d {
260    /// Creates a new BatchNorm2d layer.
261    pub fn new(num_features: usize) -> Self {
262        Self::with_options(num_features, 1e-5, 0.1)
263    }
264
265    /// Creates a BatchNorm2d with custom options.
266    pub fn with_options(num_features: usize, eps: f32, momentum: f32) -> Self {
267        Self {
268            weight: Parameter::named("weight", ones(&[num_features]), true),
269            bias: Parameter::named("bias", zeros(&[num_features]), true),
270            running_mean: RwLock::new(zeros(&[num_features])),
271            running_var: RwLock::new(ones(&[num_features])),
272            num_features,
273            eps,
274            momentum,
275            training: AtomicBool::new(true),
276        }
277    }
278
279    /// Returns the number of features (channels).
280    pub fn num_features(&self) -> usize {
281        self.num_features
282    }
283}
284
285impl Module for BatchNorm2d {
286    fn forward(&self, input: &Variable) -> Variable {
287        let input_data = input.data();
288        let shape = input_data.shape().to_vec();
289        let batch_size = shape[0];
290        let channels = shape[1];
291        let height = shape[2];
292        let width = shape[3];
293        let spatial_size = height * width;
294
295        // Validate input matches expected channels
296        assert_eq!(
297            channels, self.num_features,
298            "BatchNorm2d: expected {} channels, got {}",
299            self.num_features, channels
300        );
301
302        let input_vec = input_data.to_vec();
303        let weight_vec = self.weight.data().to_vec();
304        let bias_vec = self.bias.data().to_vec();
305
306        let is_training = self.training.load(Ordering::Relaxed);
307
308        let mut means = vec![0.0f32; channels];
309        let mut vars = vec![0.0f32; channels];
310
311        if is_training {
312            for c in 0..channels {
313                let mut sum = 0.0f32;
314                for b in 0..batch_size {
315                    for h in 0..height {
316                        for w in 0..width {
317                            let idx =
318                                b * channels * spatial_size + c * spatial_size + h * width + w;
319                            sum += input_vec[idx];
320                        }
321                    }
322                }
323                means[c] = sum / (batch_size * spatial_size) as f32;
324
325                let mut var_sum = 0.0f32;
326                for b in 0..batch_size {
327                    for h in 0..height {
328                        for w in 0..width {
329                            let idx =
330                                b * channels * spatial_size + c * spatial_size + h * width + w;
331                            let diff = input_vec[idx] - means[c];
332                            var_sum += diff * diff;
333                        }
334                    }
335                }
336                vars[c] = var_sum / (batch_size * spatial_size) as f32;
337            }
338
339            // Update running statistics
340            let mut running_mean = self.running_mean.write();
341            let mut running_var = self.running_var.write();
342            let running_mean_vec = running_mean.to_vec();
343            let running_var_vec = running_var.to_vec();
344
345            let new_mean: Vec<f32> = running_mean_vec
346                .iter()
347                .zip(means.iter())
348                .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
349                .collect();
350            let new_var: Vec<f32> = running_var_vec
351                .iter()
352                .zip(vars.iter())
353                .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
354                .collect();
355
356            *running_mean = Tensor::from_vec(new_mean, &[channels]).unwrap();
357            *running_var = Tensor::from_vec(new_var, &[channels]).unwrap();
358        } else {
359            means = self.running_mean.read().to_vec();
360            vars = self.running_var.read().to_vec();
361        }
362
363        let mut output_vec = vec![0.0f32; input_vec.len()];
364        for b in 0..batch_size {
365            for c in 0..channels {
366                for h in 0..height {
367                    for w in 0..width {
368                        let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
369                        let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
370                        output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
371                    }
372                }
373            }
374        }
375
376        let output = Tensor::from_vec(output_vec, &shape).unwrap();
377        Variable::new(output, input.requires_grad())
378    }
379
380    fn parameters(&self) -> Vec<Parameter> {
381        vec![self.weight.clone(), self.bias.clone()]
382    }
383
384    fn named_parameters(&self) -> HashMap<String, Parameter> {
385        let mut params = HashMap::new();
386        params.insert("weight".to_string(), self.weight.clone());
387        params.insert("bias".to_string(), self.bias.clone());
388        params
389    }
390
391    fn set_training(&mut self, training: bool) {
392        self.training.store(training, Ordering::Relaxed);
393    }
394
395    fn is_training(&self) -> bool {
396        self.training.load(Ordering::Relaxed)
397    }
398
399    fn name(&self) -> &'static str {
400        "BatchNorm2d"
401    }
402}
403
404// =============================================================================
405// LayerNorm
406// =============================================================================
407
/// Applies Layer Normalization over the last D dimensions.
///
/// y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta
///
/// Unlike BatchNorm, LayerNorm normalizes over features, not batch, so it
/// has no train/eval distinction and no running statistics.
pub struct LayerNorm {
    /// Learnable scale parameter (gamma); flattened to `normalized_shape.iter().product()` elements.
    pub weight: Parameter,
    /// Learnable shift parameter (beta); same flattened size as `weight`.
    pub bias: Parameter,
    /// The trailing input dimensions that are normalized over.
    normalized_shape: Vec<usize>,
    /// Epsilon added to the variance for numerical stability.
    eps: f32,
}
423
424impl LayerNorm {
425    /// Creates a new LayerNorm layer.
426    pub fn new(normalized_shape: Vec<usize>) -> Self {
427        Self::with_eps(normalized_shape, 1e-5)
428    }
429
430    /// Creates a LayerNorm for a single dimension.
431    pub fn single(size: usize) -> Self {
432        Self::new(vec![size])
433    }
434
435    /// Creates a LayerNorm with custom epsilon.
436    pub fn with_eps(normalized_shape: Vec<usize>, eps: f32) -> Self {
437        let numel: usize = normalized_shape.iter().product();
438        Self {
439            weight: Parameter::named("weight", ones(&[numel]), true),
440            bias: Parameter::named("bias", zeros(&[numel]), true),
441            normalized_shape,
442            eps,
443        }
444    }
445}
446
447impl Module for LayerNorm {
448    fn forward(&self, input: &Variable) -> Variable {
449        let input_data = input.data();
450        let shape = input_data.shape().to_vec();
451        let input_vec = input_data.to_vec();
452
453        let weight_vec = self.weight.data().to_vec();
454        let bias_vec = self.bias.data().to_vec();
455
456        // Calculate the size of the normalized dimensions
457        let norm_size: usize = self.normalized_shape.iter().product();
458        let batch_size = input_vec.len() / norm_size;
459
460        let mut output_vec = vec![0.0f32; input_vec.len()];
461
462        for b in 0..batch_size {
463            let start = b * norm_size;
464            let end = start + norm_size;
465            let slice = &input_vec[start..end];
466
467            // Calculate mean
468            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
469
470            // Calculate variance
471            let var: f32 = slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;
472
473            // Normalize and apply affine transform
474            for i in 0..norm_size {
475                let normalized = (slice[i] - mean) / (var + self.eps).sqrt();
476                output_vec[start + i] = normalized * weight_vec[i] + bias_vec[i];
477            }
478        }
479
480        let output = Tensor::from_vec(output_vec, &shape).unwrap();
481        Variable::new(output, input.requires_grad())
482    }
483
484    fn parameters(&self) -> Vec<Parameter> {
485        vec![self.weight.clone(), self.bias.clone()]
486    }
487
488    fn named_parameters(&self) -> HashMap<String, Parameter> {
489        let mut params = HashMap::new();
490        params.insert("weight".to_string(), self.weight.clone());
491        params.insert("bias".to_string(), self.bias.clone());
492        params
493    }
494
495    fn name(&self) -> &'static str {
496        "LayerNorm"
497    }
498}
499
500// =============================================================================
501// GroupNorm
502// =============================================================================
503
/// Applies Group Normalization over a mini-batch of inputs.
///
/// Groups channels and normalizes within each (batch, group) slice.
/// Particularly effective for small batch sizes where BatchNorm struggles.
/// Has no train/eval distinction and no running statistics.
///
/// # Shape
/// - Input: (N, C, *) where C must be divisible by num_groups
/// - Output: Same as input
pub struct GroupNorm {
    /// Learnable scale parameter (gamma), one value per channel; only applied when `affine`.
    pub weight: Parameter,
    /// Learnable shift parameter (beta), one value per channel; only applied when `affine`.
    pub bias: Parameter,
    /// Number of groups to divide channels into.
    num_groups: usize,
    /// Number of channels expected in input.
    num_channels: usize,
    /// Epsilon added to the variance for numerical stability.
    eps: f32,
    /// Whether to use learnable affine parameters.
    affine: bool,
}
526
527impl GroupNorm {
528    /// Creates a new GroupNorm layer.
529    ///
530    /// # Arguments
531    /// * `num_groups` - Number of groups to divide channels into
532    /// * `num_channels` - Number of channels expected in input
533    pub fn new(num_groups: usize, num_channels: usize) -> Self {
534        Self::with_options(num_groups, num_channels, 1e-5, true)
535    }
536
537    /// Creates a GroupNorm with custom options.
538    pub fn with_options(num_groups: usize, num_channels: usize, eps: f32, affine: bool) -> Self {
539        assert!(
540            num_channels % num_groups == 0,
541            "num_channels ({}) must be divisible by num_groups ({})",
542            num_channels,
543            num_groups
544        );
545
546        Self {
547            weight: Parameter::named("weight", ones(&[num_channels]), affine),
548            bias: Parameter::named("bias", zeros(&[num_channels]), affine),
549            num_groups,
550            num_channels,
551            eps,
552            affine,
553        }
554    }
555}
556
557impl Module for GroupNorm {
558    fn forward(&self, input: &Variable) -> Variable {
559        let input_data = input.data();
560        let shape = input_data.shape().to_vec();
561        let batch_size = shape[0];
562        let channels = shape[1];
563        let spatial_size: usize = shape[2..].iter().product();
564
565        assert_eq!(
566            channels, self.num_channels,
567            "GroupNorm: expected {} channels, got {}",
568            self.num_channels, channels
569        );
570
571        let input_vec = input_data.to_vec();
572        let channels_per_group = channels / self.num_groups;
573
574        let mut output_vec = vec![0.0f32; input_vec.len()];
575
576        for b in 0..batch_size {
577            for g in 0..self.num_groups {
578                // Calculate mean and variance for this group
579                let mut sum = 0.0f32;
580                let group_size = channels_per_group * spatial_size;
581
582                for c in 0..channels_per_group {
583                    let channel_idx = g * channels_per_group + c;
584                    for s in 0..spatial_size {
585                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
586                        sum += input_vec[idx];
587                    }
588                }
589                let mean = sum / group_size as f32;
590
591                let mut var_sum = 0.0f32;
592                for c in 0..channels_per_group {
593                    let channel_idx = g * channels_per_group + c;
594                    for s in 0..spatial_size {
595                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
596                        let diff = input_vec[idx] - mean;
597                        var_sum += diff * diff;
598                    }
599                }
600                let var = var_sum / group_size as f32;
601
602                // Normalize
603                let std_inv = 1.0 / (var + self.eps).sqrt();
604                for c in 0..channels_per_group {
605                    let channel_idx = g * channels_per_group + c;
606                    let weight = if self.affine {
607                        self.weight.data().to_vec()[channel_idx]
608                    } else {
609                        1.0
610                    };
611                    let bias = if self.affine {
612                        self.bias.data().to_vec()[channel_idx]
613                    } else {
614                        0.0
615                    };
616
617                    for s in 0..spatial_size {
618                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
619                        let normalized = (input_vec[idx] - mean) * std_inv;
620                        output_vec[idx] = normalized * weight + bias;
621                    }
622                }
623            }
624        }
625
626        let output = Tensor::from_vec(output_vec, &shape).unwrap();
627        Variable::new(output, input.requires_grad())
628    }
629
630    fn parameters(&self) -> Vec<Parameter> {
631        if self.affine {
632            vec![self.weight.clone(), self.bias.clone()]
633        } else {
634            vec![]
635        }
636    }
637
638    fn named_parameters(&self) -> HashMap<String, Parameter> {
639        if self.affine {
640            let mut params = HashMap::new();
641            params.insert("weight".to_string(), self.weight.clone());
642            params.insert("bias".to_string(), self.bias.clone());
643            params
644        } else {
645            HashMap::new()
646        }
647    }
648
649    fn name(&self) -> &'static str {
650        "GroupNorm"
651    }
652}
653
654// =============================================================================
655// InstanceNorm2d
656// =============================================================================
657
/// Applies Instance Normalization over a 4D input (images).
///
/// Each channel in each sample is normalized independently over its
/// spatial dimensions. Particularly useful for style transfer and image
/// generation. Has no train/eval distinction and no running statistics.
///
/// # Shape
/// - Input: (N, C, H, W)
/// - Output: Same as input
pub struct InstanceNorm2d {
    /// Learnable scale parameter (gamma), one value per channel; only applied when `affine`.
    pub weight: Parameter,
    /// Learnable shift parameter (beta), one value per channel; only applied when `affine`.
    pub bias: Parameter,
    /// Number of features (channels, the C dimension).
    num_features: usize,
    /// Epsilon added to the variance for numerical stability.
    eps: f32,
    /// Whether to use learnable affine parameters.
    affine: bool,
}
678
679impl InstanceNorm2d {
680    /// Creates a new InstanceNorm2d layer.
681    pub fn new(num_features: usize) -> Self {
682        Self::with_options(num_features, 1e-5, false)
683    }
684
685    /// Creates an InstanceNorm2d with affine parameters.
686    pub fn with_affine(num_features: usize) -> Self {
687        Self::with_options(num_features, 1e-5, true)
688    }
689
690    /// Creates an InstanceNorm2d with custom options.
691    pub fn with_options(num_features: usize, eps: f32, affine: bool) -> Self {
692        Self {
693            weight: Parameter::named("weight", ones(&[num_features]), affine),
694            bias: Parameter::named("bias", zeros(&[num_features]), affine),
695            num_features,
696            eps,
697            affine,
698        }
699    }
700}
701
702impl Module for InstanceNorm2d {
703    fn forward(&self, input: &Variable) -> Variable {
704        let input_data = input.data();
705        let shape = input_data.shape().to_vec();
706
707        assert!(
708            shape.len() == 4,
709            "InstanceNorm2d expects 4D input (N, C, H, W)"
710        );
711
712        let batch_size = shape[0];
713        let channels = shape[1];
714        let height = shape[2];
715        let width = shape[3];
716        let spatial_size = height * width;
717
718        assert_eq!(
719            channels, self.num_features,
720            "InstanceNorm2d: expected {} channels, got {}",
721            self.num_features, channels
722        );
723
724        let input_vec = input_data.to_vec();
725        let mut output_vec = vec![0.0f32; input_vec.len()];
726
727        for b in 0..batch_size {
728            for c in 0..channels {
729                // Calculate mean for this (batch, channel) pair
730                let mut sum = 0.0f32;
731                for s in 0..spatial_size {
732                    let idx = b * channels * spatial_size + c * spatial_size + s;
733                    sum += input_vec[idx];
734                }
735                let mean = sum / spatial_size as f32;
736
737                // Calculate variance
738                let mut var_sum = 0.0f32;
739                for s in 0..spatial_size {
740                    let idx = b * channels * spatial_size + c * spatial_size + s;
741                    let diff = input_vec[idx] - mean;
742                    var_sum += diff * diff;
743                }
744                let var = var_sum / spatial_size as f32;
745
746                // Normalize and apply affine
747                let std_inv = 1.0 / (var + self.eps).sqrt();
748                let weight = if self.affine {
749                    self.weight.data().to_vec()[c]
750                } else {
751                    1.0
752                };
753                let bias = if self.affine {
754                    self.bias.data().to_vec()[c]
755                } else {
756                    0.0
757                };
758
759                for s in 0..spatial_size {
760                    let idx = b * channels * spatial_size + c * spatial_size + s;
761                    let normalized = (input_vec[idx] - mean) * std_inv;
762                    output_vec[idx] = normalized * weight + bias;
763                }
764            }
765        }
766
767        let output = Tensor::from_vec(output_vec, &shape).unwrap();
768        Variable::new(output, input.requires_grad())
769    }
770
771    fn parameters(&self) -> Vec<Parameter> {
772        if self.affine {
773            vec![self.weight.clone(), self.bias.clone()]
774        } else {
775            vec![]
776        }
777    }
778
779    fn named_parameters(&self) -> HashMap<String, Parameter> {
780        if self.affine {
781            let mut params = HashMap::new();
782            params.insert("weight".to_string(), self.weight.clone());
783            params.insert("bias".to_string(), self.bias.clone());
784            params
785        } else {
786            HashMap::new()
787        }
788    }
789
790    fn name(&self) -> &'static str {
791        "InstanceNorm2d"
792    }
793}
794
795// =============================================================================
796// Tests
797// =============================================================================
798
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps raw data in a non-grad Variable for the shape-preservation tests.
    fn var(data: Vec<f32>, shape: &[usize]) -> Variable {
        Variable::new(Tensor::from_vec(data, shape).unwrap(), false)
    }

    #[test]
    fn test_batchnorm1d() {
        let layer = BatchNorm1d::new(3);
        let input = var(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        assert_eq!(layer.forward(&input).shape(), vec![2, 3]);
    }

    #[test]
    fn test_batchnorm2d() {
        let layer = BatchNorm2d::new(2);
        let input = var(vec![1.0; 32], &[2, 2, 2, 4]);
        assert_eq!(layer.forward(&input).shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_layernorm() {
        let layer = LayerNorm::single(4);
        let input = var(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]);
        assert_eq!(layer.forward(&input).shape(), vec![2, 4]);
    }

    #[test]
    fn test_batchnorm_parameters() {
        let layer = BatchNorm1d::new(10);
        // Exactly gamma + beta, each of length 10.
        assert_eq!(layer.parameters().len(), 2);
        assert_eq!(layer.num_parameters(), 20); // weight + bias
    }

    #[test]
    fn test_groupnorm() {
        // 2 groups over 4 channels.
        let layer = GroupNorm::new(2, 4);
        let input = var(vec![1.0; 32], &[2, 4, 2, 2]);
        assert_eq!(layer.forward(&input).shape(), vec![2, 4, 2, 2]);
    }

    #[test]
    fn test_groupnorm_normalization() {
        // Disable affine so the raw normalized values are observable.
        let layer = GroupNorm::with_options(2, 4, 1e-5, false);
        let input = var(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 4, 1, 2]);
        let out_vec = layer.forward(&input).data().to_vec();

        // Group 1 covers channels 0,1 (values 1..4); group 2 covers
        // channels 2,3 (values 5..8). Each group should be zero-mean.
        let g1: f32 = out_vec[0..4].iter().sum::<f32>() / 4.0;
        let g2: f32 = out_vec[4..8].iter().sum::<f32>() / 4.0;
        assert!(g1.abs() < 1e-5);
        assert!(g2.abs() < 1e-5);
    }

    #[test]
    fn test_instancenorm2d() {
        let layer = InstanceNorm2d::new(2);
        let input = var(vec![1.0; 32], &[2, 2, 2, 4]);
        assert_eq!(layer.forward(&input).shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_instancenorm2d_with_affine() {
        let layer = InstanceNorm2d::with_affine(4);
        let input = var(vec![1.0; 64], &[1, 4, 4, 4]);
        assert_eq!(layer.forward(&input).shape(), vec![1, 4, 4, 4]);
        // Affine mode exposes gamma and beta as learnable parameters.
        assert_eq!(layer.parameters().len(), 2);
    }
}