use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::Variable;
use axonml_tensor::Tensor;
use parking_lot::RwLock;

use crate::init::{ones, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

/// Batch normalization over a 2D or 3D input, (N, C) or (N, C, L).
///
/// Normalizes each feature channel with per-batch statistics in training
/// mode and with the tracked running statistics in evaluation mode:
/// `y = (x - mean) / sqrt(var + eps) * weight + bias`.
pub struct BatchNorm1d {
    /// Learnable per-feature scale (gamma), initialized to ones.
    pub weight: Parameter,
    /// Learnable per-feature shift (beta), initialized to zeros.
    pub bias: Parameter,
    /// Running mean, updated during training when `track_running_stats` is set.
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance, updated during training when `track_running_stats` is set.
    running_var: RwLock<Tensor<f32>>,
    /// Number of feature channels (the size of dimension 1).
    num_features: usize,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
    /// Interpolation factor for the running-statistics update.
    momentum: f32,
    /// Whether running statistics are tracked and used in eval mode.
    track_running_stats: bool,
    /// Training/eval flag; atomic so `forward` can take `&self`.
    training: AtomicBool,
}

impl BatchNorm1d {
    /// Creates a `BatchNorm1d` with default options
    /// (`eps = 1e-5`, `momentum = 0.1`, running stats tracked).
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1, true)
    }

    /// Creates a `BatchNorm1d` with explicit options.
    pub fn with_options(
        num_features: usize,
        eps: f32,
        momentum: f32,
        track_running_stats: bool,
    ) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            track_running_stats,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of feature channels this layer expects.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm1d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        assert!(
            shape.len() >= 2,
            "BatchNorm1d expects input of shape (N, C) or (N, C, L)"
        );
        let batch_size = shape[0];
        let num_features = shape[1];

        assert_eq!(
            num_features, self.num_features,
            "BatchNorm1d: expected {} features, got {}",
            self.num_features, num_features
        );

        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let is_training = self.training.load(Ordering::Relaxed);
        // For (N, C, L) inputs, normalize over both the batch and the length
        // dimension; for (N, C) inputs the spatial size is 1.
        let spatial_size: usize = if shape.len() > 2 {
            shape[2..].iter().product()
        } else {
            1
        };

        let mut means = vec![0.0f32; num_features];
        let mut vars = vec![0.0f32; num_features];

        if is_training {
            // Compute per-channel batch statistics (biased variance).
            for c in 0..num_features {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        let diff = input_vec[idx] - means[c];
                        var_sum += diff * diff;
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            if self.track_running_stats {
                // Exponential moving average:
                // running = (1 - momentum) * running + momentum * batch.
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();

                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();

                *running_mean = Tensor::from_vec(new_mean, &[num_features]).unwrap();
                *running_var = Tensor::from_vec(new_var, &[num_features]).unwrap();
            }
        } else {
            // Evaluation mode: normalize with the tracked running statistics.
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..num_features {
                for s in 0..spatial_size {
                    let idx = b * num_features * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm1d"
    }
}

/// Batch normalization over a 4D input (N, C, H, W).
///
/// Normalizes each channel with per-batch statistics in training mode and
/// with the running statistics in evaluation mode. Unlike `BatchNorm1d`,
/// running statistics are always tracked.
pub struct BatchNorm2d {
    /// Learnable per-channel scale (gamma), initialized to ones.
    pub weight: Parameter,
    /// Learnable per-channel shift (beta), initialized to zeros.
    pub bias: Parameter,
    /// Running mean, updated during training.
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance, updated during training.
    running_var: RwLock<Tensor<f32>>,
    /// Number of channels (the size of dimension 1).
    num_features: usize,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
    /// Interpolation factor for the running-statistics update.
    momentum: f32,
    /// Training/eval flag; atomic so `forward` can take `&self`.
    training: AtomicBool,
}

impl BatchNorm2d {
    /// Creates a `BatchNorm2d` with default options (`eps = 1e-5`, `momentum = 0.1`).
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1)
    }

    /// Creates a `BatchNorm2d` with explicit options.
    pub fn with_options(num_features: usize, eps: f32, momentum: f32) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of channels this layer expects.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        assert!(
            shape.len() == 4,
            "BatchNorm2d expects 4D input (N, C, H, W)"
        );
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "BatchNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let is_training = self.training.load(Ordering::Relaxed);

        let mut means = vec![0.0f32; channels];
        let mut vars = vec![0.0f32; channels];

        if is_training {
            // Compute per-channel batch statistics over (N, H, W).
            for c in 0..channels {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx =
                                b * channels * spatial_size + c * spatial_size + h * width + w;
                            sum += input_vec[idx];
                        }
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for h in 0..height {
                        for w in 0..width {
                            let idx =
                                b * channels * spatial_size + c * spatial_size + h * width + w;
                            let diff = input_vec[idx] - means[c];
                            var_sum += diff * diff;
                        }
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            // Exponential moving average of the running statistics.
            let mut running_mean = self.running_mean.write();
            let mut running_var = self.running_var.write();
            let running_mean_vec = running_mean.to_vec();
            let running_var_vec = running_var.to_vec();

            let new_mean: Vec<f32> = running_mean_vec
                .iter()
                .zip(means.iter())
                .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                .collect();
            let new_var: Vec<f32> = running_var_vec
                .iter()
                .zip(vars.iter())
                .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                .collect();

            *running_mean = Tensor::from_vec(new_mean, &[channels]).unwrap();
            *running_var = Tensor::from_vec(new_var, &[channels]).unwrap();
        } else {
            // Evaluation mode: normalize with the tracked running statistics.
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..channels {
                for h in 0..height {
                    for w in 0..width {
                        let idx = b * channels * spatial_size + c * spatial_size + h * width + w;
                        let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                        output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm2d"
    }
}

/// Layer normalization over the trailing `normalized_shape` dimensions.
///
/// Each sample is normalized with its own mean and variance, so no running
/// statistics are needed and behavior is identical in training and eval.
pub struct LayerNorm {
    /// Learnable elementwise scale (gamma), initialized to ones.
    pub weight: Parameter,
    /// Learnable elementwise shift (beta), initialized to zeros.
    pub bias: Parameter,
    /// Shape of the trailing dimensions that are normalized together.
    normalized_shape: Vec<usize>,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
}

impl LayerNorm {
    /// Creates a `LayerNorm` with the default `eps = 1e-5`.
    pub fn new(normalized_shape: Vec<usize>) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    /// Convenience constructor for normalizing a single trailing dimension.
    pub fn single(size: usize) -> Self {
        Self::new(vec![size])
    }

    /// Creates a `LayerNorm` with an explicit `eps`.
    pub fn with_eps(normalized_shape: Vec<usize>, eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();
        Self {
            weight: Parameter::named("weight", ones(&[numel]), true),
            bias: Parameter::named("bias", zeros(&[numel]), true),
            normalized_shape,
            eps,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let input_vec = input_data.to_vec();

        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        // Treat the input as (batch_size, norm_size): everything before the
        // normalized dimensions is flattened into the batch.
        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = input_vec.len() / norm_size;

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_vec[start..end];

            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
            let var: f32 = slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;

            for i in 0..norm_size {
                let normalized = (slice[i] - mean) / (var + self.eps).sqrt();
                output_vec[start + i] = normalized * weight_vec[i] + bias_vec[i];
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LayerNorm"
    }
}

/// Group normalization: channels are split into groups and each group is
/// normalized with its own per-sample statistics, independent of batch size.
pub struct GroupNorm {
    /// Learnable per-channel scale (gamma); only trainable when `affine`.
    pub weight: Parameter,
    /// Learnable per-channel shift (beta); only trainable when `affine`.
    pub bias: Parameter,
    /// Number of groups the channels are divided into.
    num_groups: usize,
    /// Total number of channels (must be divisible by `num_groups`).
    num_channels: usize,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
    /// Whether the learnable affine transform is applied.
    affine: bool,
}

impl GroupNorm {
    /// Creates a `GroupNorm` with default options (`eps = 1e-5`, affine).
    pub fn new(num_groups: usize, num_channels: usize) -> Self {
        Self::with_options(num_groups, num_channels, 1e-5, true)
    }

    /// Creates a `GroupNorm` with explicit options.
    ///
    /// # Panics
    /// Panics if `num_channels` is not divisible by `num_groups`.
    pub fn with_options(num_groups: usize, num_channels: usize, eps: f32, affine: bool) -> Self {
        assert!(
            num_channels % num_groups == 0,
            "num_channels ({}) must be divisible by num_groups ({})",
            num_channels,
            num_groups
        );

        Self {
            weight: Parameter::named("weight", ones(&[num_channels]), affine),
            bias: Parameter::named("bias", zeros(&[num_channels]), affine),
            num_groups,
            num_channels,
            eps,
            affine,
        }
    }
}

impl Module for GroupNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        assert_eq!(
            channels, self.num_channels,
            "GroupNorm: expected {} channels, got {}",
            self.num_channels, channels
        );

        let input_vec = input_data.to_vec();
        let channels_per_group = channels / self.num_groups;

        // Hoist the affine parameters out of the loops so they are copied
        // once per forward pass rather than once per channel.
        let weight_vec = if self.affine {
            Some(self.weight.data().to_vec())
        } else {
            None
        };
        let bias_vec = if self.affine {
            Some(self.bias.data().to_vec())
        } else {
            None
        };

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            for g in 0..self.num_groups {
                // Mean over all channels and spatial positions in this group.
                let mut sum = 0.0f32;
                let group_size = channels_per_group * spatial_size;

                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                let mean = sum / group_size as f32;

                // Biased variance over the same group.
                let mut var_sum = 0.0f32;
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let diff = input_vec[idx] - mean;
                        var_sum += diff * diff;
                    }
                }
                let var = var_sum / group_size as f32;

                let std_inv = 1.0 / (var + self.eps).sqrt();
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    let weight = weight_vec.as_ref().map_or(1.0, |wv| wv[channel_idx]);
                    let bias = bias_vec.as_ref().map_or(0.0, |bv| bv[channel_idx]);

                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let normalized = (input_vec[idx] - mean) * std_inv;
                        output_vec[idx] = normalized * weight + bias;
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "GroupNorm"
    }
}

/// Instance normalization over a 4D input (N, C, H, W).
///
/// Each (sample, channel) plane is normalized with its own mean and
/// variance; no running statistics are tracked.
pub struct InstanceNorm2d {
    /// Learnable per-channel scale (gamma); only trainable when `affine`.
    pub weight: Parameter,
    /// Learnable per-channel shift (beta); only trainable when `affine`.
    pub bias: Parameter,
    /// Number of channels (the size of dimension 1).
    num_features: usize,
    /// Small constant added to the variance for numerical stability.
    eps: f32,
    /// Whether the learnable affine transform is applied.
    affine: bool,
}

impl InstanceNorm2d {
    /// Creates an `InstanceNorm2d` with default options (`eps = 1e-5`, no affine).
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, false)
    }

    /// Creates an `InstanceNorm2d` with learnable affine parameters.
    pub fn with_affine(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, true)
    }

    /// Creates an `InstanceNorm2d` with explicit options.
    pub fn with_options(num_features: usize, eps: f32, affine: bool) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), affine),
            bias: Parameter::named("bias", zeros(&[num_features]), affine),
            num_features,
            eps,
            affine,
        }
    }
}

impl Module for InstanceNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();

        assert!(
            shape.len() == 4,
            "InstanceNorm2d expects 4D input (N, C, H, W)"
        );

        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "InstanceNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let input_vec = input_data.to_vec();
        let mut output_vec = vec![0.0f32; input_vec.len()];

        // Hoist the affine parameters out of the loops so they are copied
        // once per forward pass rather than once per (sample, channel).
        let weight_vec = if self.affine {
            Some(self.weight.data().to_vec())
        } else {
            None
        };
        let bias_vec = if self.affine {
            Some(self.bias.data().to_vec())
        } else {
            None
        };

        for b in 0..batch_size {
            for c in 0..channels {
                // Mean over this (sample, channel) plane.
                let mut sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    sum += input_vec[idx];
                }
                let mean = sum / spatial_size as f32;

                // Biased variance over the same plane.
                let mut var_sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let diff = input_vec[idx] - mean;
                    var_sum += diff * diff;
                }
                let var = var_sum / spatial_size as f32;

                let std_inv = 1.0 / (var + self.eps).sqrt();
                let weight = weight_vec.as_ref().map_or(1.0, |wv| wv[c]);
                let bias = bias_vec.as_ref().map_or(0.0, |bv| bv[c]);

                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - mean) * std_inv;
                    output_vec[idx] = normalized * weight + bias;
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        Variable::new(output, input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "InstanceNorm2d"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm1d() {
        let bn = BatchNorm1d::new(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }
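
    // A hedged extra check, relying on the running statistics keeping their
    // initial values (mean = 0, var = 1) until a training-mode forward pass:
    // in eval mode the layer should then be close to an identity map.
    #[test]
    fn test_batchnorm1d_eval_uses_running_stats() {
        let mut bn = BatchNorm1d::new(3);
        bn.set_training(false);
        let input_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let input = Variable::new(
            Tensor::from_vec(input_data.clone(), &[2, 3]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        let out_vec = output.data().to_vec();
        for (o, i) in out_vec.iter().zip(input_data.iter()) {
            // (x - 0) / sqrt(1 + eps) is within ~1e-4 of x for these values.
            assert!((o - i).abs() < 1e-3);
        }
    }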

    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }
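
    // A hedged sketch of the training-mode statistics: with the biased
    // variance used in forward(), each channel of the output should have
    // approximately zero mean and unit variance.
    #[test]
    fn test_batchnorm2d_normalizes_per_channel() {
        let bn = BatchNorm2d::new(1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        let out = output.data().to_vec();
        let mean: f32 = out.iter().sum::<f32>() / 4.0;
        let var: f32 = out.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / 4.0;
        assert!(mean.abs() < 1e-5);
        assert!((var - 1.0).abs() < 1e-3);
    }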

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap(),
            false,
        );
        let output = ln.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }
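
    // A hedged extra check: each normalized row should come out with
    // approximately zero mean and unit (biased) variance.
    #[test]
    fn test_layernorm_row_statistics() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap(),
            false,
        );
        let output = ln.forward(&input);
        let out = output.data().to_vec();
        for row in out.chunks(4) {
            let mean: f32 = row.iter().sum::<f32>() / 4.0;
            let var: f32 = row.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / 4.0;
            assert!(mean.abs() < 1e-5);
            assert!((var - 1.0).abs() < 1e-3);
        }
    }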

    #[test]
    fn test_batchnorm_parameters() {
        let bn = BatchNorm1d::new(10);
        assert_eq!(bn.parameters().len(), 2);
        assert_eq!(bn.num_parameters(), 20);
    }

    #[test]
    fn test_groupnorm() {
        let gn = GroupNorm::new(2, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 4, 2, 2]).unwrap(),
            false,
        );
        let output = gn.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 2, 2]);
    }

    #[test]
    fn test_groupnorm_normalization() {
        let gn = GroupNorm::with_options(2, 4, 1e-5, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 4, 1, 2]).unwrap(),
            false,
        );
        let output = gn.forward(&input);
        let out_vec = output.data().to_vec();
        let group1_mean: f32 = out_vec[0..4].iter().sum::<f32>() / 4.0;
        let group2_mean: f32 = out_vec[4..8].iter().sum::<f32>() / 4.0;
        assert!(group1_mean.abs() < 1e-5);
        assert!(group2_mean.abs() < 1e-5);
    }
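
    // A small follow-up sketch: with affine = false the layer should expose
    // no trainable parameters, mirroring parameters() above.
    #[test]
    fn test_groupnorm_no_affine_parameters() {
        let gn = GroupNorm::with_options(2, 4, 1e-5, false);
        assert!(gn.parameters().is_empty());
        assert!(gn.named_parameters().is_empty());
    }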

    #[test]
    fn test_instancenorm2d() {
        let inn = InstanceNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }
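
    // A hedged edge-case sketch: a constant plane has zero variance, so every
    // normalized value is 0 * (1 / sqrt(eps)) = 0, and without affine
    // parameters the output stays all zeros.
    #[test]
    fn test_instancenorm2d_constant_input() {
        let inn = InstanceNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![3.0; 16], &[1, 2, 2, 4]).unwrap(),
            false,
        );
        let output = inn.forward(&input);
        for v in output.data().to_vec() {
            assert!(v.abs() < 1e-6);
        }
    }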

    #[test]
    fn test_instancenorm2d_with_affine() {
        let inn = InstanceNorm2d::with_affine(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 64], &[1, 4, 4, 4]).unwrap(),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 4, 4]);
        assert_eq!(inn.parameters().len(), 2);
    }
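
    // A hedged check of the training-flag plumbing on the Module trait.
    #[test]
    fn test_batchnorm_training_toggle() {
        let mut bn = BatchNorm1d::new(3);
        assert!(bn.is_training());
        bn.set_training(false);
        assert!(!bn.is_training());
        bn.set_training(true);
        assert!(bn.is_training());
    }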
}