use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::Variable;
use axonml_autograd::functions::{
    BatchNorm1dBackward, BatchNorm2dBackward, GroupNormBackward, InstanceNorm2dBackward,
    LayerNormBackward,
};
use axonml_autograd::grad_fn::GradFn;
use axonml_autograd::no_grad::is_grad_enabled;
use axonml_tensor::Tensor;
use parking_lot::RwLock;

use crate::init::{ones, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

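/// Batch normalization over a 2D (`[N, C]`) or 3D (`[N, C, L]`) input.
///
/// Each feature channel is normalized as
/// `y = (x - mean) / sqrt(var + eps) * weight + bias`, using batch statistics
/// in training mode and the tracked running statistics in evaluation mode.
/// Running statistics are updated with an exponential moving average:
/// `running = (1 - momentum) * running + momentum * batch_stat`.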
pub struct BatchNorm1d {
    /// Learnable per-feature scale (gamma).
    pub weight: Parameter,
    /// Learnable per-feature shift (beta).
    pub bias: Parameter,
    /// Running mean estimate, updated in training when `track_running_stats` is set.
    running_mean: RwLock<Tensor<f32>>,
    /// Running (biased) variance estimate, updated alongside `running_mean`.
    running_var: RwLock<Tensor<f32>>,
    num_features: usize,
    eps: f32,
    momentum: f32,
    track_running_stats: bool,
    training: AtomicBool,
}

impl BatchNorm1d {
    /// Creates a `BatchNorm1d` with the defaults `eps = 1e-5`, `momentum = 0.1`,
    /// and running-statistics tracking enabled.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1, true)
    }

    /// Creates a `BatchNorm1d` with explicit options.
    pub fn with_options(
        num_features: usize,
        eps: f32,
        momentum: f32,
        track_running_stats: bool,
    ) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            track_running_stats,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of features this layer normalizes over.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm1d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_features = shape[1];

        assert_eq!(
            num_features, self.num_features,
            "BatchNorm1d: expected {} features, got {}",
            self.num_features, num_features
        );

        let is_training = self.training.load(Ordering::Relaxed);
        // For a 3D input `[N, C, L]` the trailing dimensions fold into one
        // spatial axis; a plain 2D input has a spatial size of 1.
        let spatial_size: usize = if shape.len() > 2 {
            shape[2..].iter().product()
        } else {
            1
        };

        // Fused GPU fast path: normalize on-device and pull the batch
        // statistics back for the running-average update.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) = input_data.batchnorm_fused(
                &gamma_gpu,
                &beta_gpu,
                self.eps,
                num_features,
                spatial_size,
            ) {
                // Update the running statistics with an exponential moving average.
                if self.track_running_stats {
                    let mut running_mean = self.running_mean.write();
                    let mut running_var = self.running_var.write();
                    let running_mean_vec = running_mean.to_vec();
                    let running_var_vec = running_var.to_vec();
                    let new_mean: Vec<f32> = running_mean_vec
                        .iter()
                        .zip(means.iter())
                        .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                        .collect();
                    let new_var: Vec<f32> = running_var_vec
                        .iter()
                        .zip(vars.iter())
                        .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                        .collect();
                    *running_mean = Tensor::from_vec(new_mean, &[num_features]).unwrap();
                    *running_var = Tensor::from_vec(new_var, &[num_features]).unwrap();
                }

                let weight_vec = gamma_gpu.to_vec();
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU fallback path.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; num_features];
        let mut vars = vec![0.0f32; num_features];

        if is_training {
            // Per-feature batch statistics. Note that the biased variance
            // (divide by N) is used both for normalization and for the
            // running estimate.
            for c in 0..num_features {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        let diff = input_vec[idx] - means[c];
                        var_sum += diff * diff;
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            // Update the running statistics with an exponential moving average.
            if self.track_running_stats {
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();

                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();

                *running_mean = Tensor::from_vec(new_mean, &[num_features]).unwrap();
                *running_var = Tensor::from_vec(new_var, &[num_features]).unwrap();
            }
        } else {
            // Evaluation mode normalizes with the tracked running statistics
            // (or their initial values of 0 and 1 if tracking is disabled).
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..num_features {
                for s in 0..spatial_size {
                    let idx = b * num_features * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm1d"
    }
}

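/// Batch normalization over a 4D input `[N, C, H, W]`.
///
/// Statistics are computed per channel across the batch and spatial
/// dimensions, with the same training/evaluation split and running-statistics
/// update as [`BatchNorm1d`]; tracking is always enabled for this layer.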
pub struct BatchNorm2d {
    /// Learnable per-channel scale (gamma).
    pub weight: Parameter,
    /// Learnable per-channel shift (beta).
    pub bias: Parameter,
    /// Running mean estimate, updated in training mode.
    running_mean: RwLock<Tensor<f32>>,
    /// Running (biased) variance estimate, updated in training mode.
    running_var: RwLock<Tensor<f32>>,
    num_features: usize,
    eps: f32,
    momentum: f32,
    training: AtomicBool,
}

impl BatchNorm2d {
    /// Creates a `BatchNorm2d` with the defaults `eps = 1e-5` and `momentum = 0.1`.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1)
    }

    /// Creates a `BatchNorm2d` with explicit options.
    pub fn with_options(num_features: usize, eps: f32, momentum: f32) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of channels this layer normalizes over.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "BatchNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let is_training = self.training.load(Ordering::Relaxed);

        // Fused GPU fast path, mirroring BatchNorm1d.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) =
                input_data.batchnorm_fused(&gamma_gpu, &beta_gpu, self.eps, channels, spatial_size)
            {
                // Update the running statistics with an exponential moving average.
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();
                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();
                *running_mean = Tensor::from_vec(new_mean, &[channels]).unwrap();
                *running_var = Tensor::from_vec(new_var, &[channels]).unwrap();

                let weight_vec = gamma_gpu.to_vec();
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU fallback path.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; channels];
        let mut vars = vec![0.0f32; channels];

        if is_training {
            let n_per_channel = (batch_size * spatial_size) as f32;
            for c in 0..channels {
                let mut sum = 0.0f32;
                let mut sum_sq = 0.0f32;
                for b in 0..batch_size {
                    let base = b * channels * spatial_size + c * spatial_size;
                    for s in 0..spatial_size {
                        let val = input_vec[base + s];
                        sum += val;
                        sum_sq += val * val;
                    }
                }
                means[c] = sum / n_per_channel;
                // Single-pass (biased) variance: Var[x] = E[x^2] - E[x]^2.
                vars[c] = sum_sq / n_per_channel - means[c] * means[c];
            }

            // Update the running statistics with an exponential moving average.
            let mut running_mean = self.running_mean.write();
            let mut running_var = self.running_var.write();
            let running_mean_vec = running_mean.to_vec();
            let running_var_vec = running_var.to_vec();

            let new_mean: Vec<f32> = running_mean_vec
                .iter()
                .zip(means.iter())
                .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                .collect();
            let new_var: Vec<f32> = running_var_vec
                .iter()
                .zip(vars.iter())
                .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                .collect();

            *running_mean = Tensor::from_vec(new_mean, &[channels]).unwrap();
            *running_var = Tensor::from_vec(new_var, &[channels]).unwrap();
        } else {
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let total = input_vec.len();
        let mut output_vec = vec![0.0f32; total];

        // Precompute 1 / sqrt(var + eps) per channel, then normalize in a
        // single flat pass over the NCHW buffer.
        let inv_stds: Vec<f32> = vars.iter().map(|v| 1.0 / (v + self.eps).sqrt()).collect();

        for i in 0..total {
            let c = (i / spatial_size) % channels;
            output_vec[i] = (input_vec[i] - means[c]) * inv_stds[c] * weight_vec[c] + bias_vec[c];
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm2d"
    }
}

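/// Layer normalization over the trailing `normalized_shape` dimensions.
///
/// Statistics are computed per sample rather than per batch, so the layer
/// behaves identically in training and evaluation and keeps no running
/// statistics.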
pub struct LayerNorm {
    /// Learnable elementwise scale over the normalized dimensions.
    pub weight: Parameter,
    /// Learnable elementwise shift over the normalized dimensions.
    pub bias: Parameter,
    normalized_shape: Vec<usize>,
    eps: f32,
}

impl LayerNorm {
    /// Creates a `LayerNorm` with the default `eps = 1e-5`.
    pub fn new(normalized_shape: Vec<usize>) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    /// Convenience constructor for normalizing over a single trailing dimension.
    pub fn single(size: usize) -> Self {
        Self::new(vec![size])
    }

    /// Creates a `LayerNorm` with an explicit epsilon.
    pub fn with_eps(normalized_shape: Vec<usize>, eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();
        Self {
            weight: Parameter::named("weight", ones(&[numel]), true),
            bias: Parameter::named("bias", zeros(&[numel]), true),
            normalized_shape,
            eps,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let norm_size: usize = self.normalized_shape.iter().product();
        let total_len = input_data.numel();
        let num_rows = total_len / norm_size;

        // Fused GPU fast path.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            let weight_data = self.weight.data();
            let weight_gpu = if weight_data.device().is_gpu() {
                weight_data.clone()
            } else {
                weight_data.to_device(input_data.device().clone()).unwrap()
            };
            let bias_data = self.bias.data();
            let bias_gpu = if bias_data.device().is_gpu() {
                bias_data.clone()
            } else {
                bias_data.to_device(input_data.device().clone()).unwrap()
            };

            let output = input_data
                .layer_norm_cuda(&weight_gpu, &bias_gpu, norm_size, self.eps)
                .expect("CUDA LayerNorm failed");

            // Track gradients if either the input or the parameters need them.
            let requires_grad =
                (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
            return if requires_grad {
                let grad_fn = GradFn::new(LayerNormBackward::new(
                    input.grad_fn().cloned(),
                    self.weight.variable().grad_fn().cloned(),
                    self.bias.variable().grad_fn().cloned(),
                    input_data.clone(),
                    self.weight.data().clone(),
                    self.normalized_shape.clone(),
                    self.eps,
                ));
                Variable::from_operation(output, grad_fn, true)
            } else {
                Variable::from_tensor(output)
            };
        }

        // CPU fallback: treat the input as `num_rows` contiguous rows of
        // `norm_size` elements and normalize each row independently.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..num_rows {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_vec[start..end];

            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
            let var: f32 = slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;

            for i in 0..norm_size {
                let normalized = (slice[i] - mean) / (var + self.eps).sqrt();
                output_vec[start + i] = normalized * weight_vec[i] + bias_vec[i];
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        // Track gradients if either the input or the parameters need them.
        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LayerNormBackward::new(
                input.grad_fn().cloned(),
                self.weight.variable().grad_fn().cloned(),
                self.bias.variable().grad_fn().cloned(),
                input_data.clone(),
                self.weight.data().clone(),
                self.normalized_shape.clone(),
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LayerNorm"
    }
}

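/// Group normalization over a `[N, C, ...]` input.
///
/// The channels are split into `num_groups` groups, and each (sample, group)
/// slice is normalized over its channels and spatial positions. The
/// statistics are independent of the batch size; the optional affine
/// transform applies a per-channel scale and shift.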
pub struct GroupNorm {
    /// Learnable per-channel scale (used only when `affine` is true).
    pub weight: Parameter,
    /// Learnable per-channel shift (used only when `affine` is true).
    pub bias: Parameter,
    num_groups: usize,
    num_channels: usize,
    eps: f32,
    affine: bool,
}

impl GroupNorm {
    /// Creates a `GroupNorm` with the default `eps = 1e-5` and the affine
    /// transform enabled.
    pub fn new(num_groups: usize, num_channels: usize) -> Self {
        Self::with_options(num_groups, num_channels, 1e-5, true)
    }

    /// Creates a `GroupNorm` with explicit options. Panics if `num_channels`
    /// is not divisible by `num_groups`.
    pub fn with_options(num_groups: usize, num_channels: usize, eps: f32, affine: bool) -> Self {
        assert!(
            num_channels % num_groups == 0,
            "num_channels ({}) must be divisible by num_groups ({})",
            num_channels,
            num_groups
        );

        Self {
            weight: Parameter::named("weight", ones(&[num_channels]), affine),
            bias: Parameter::named("bias", zeros(&[num_channels]), affine),
            num_groups,
            num_channels,
            eps,
            affine,
        }
    }
}

impl Module for GroupNorm {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        assert_eq!(
            channels, self.num_channels,
            "GroupNorm: expected {} channels, got {}",
            self.num_channels, channels
        );

        let input_vec = input_data.to_vec();
        let channels_per_group = channels / self.num_groups;

        // Hoist the affine parameters out of the loops; without affine the
        // transform is the identity (scale 1, shift 0).
        let (weight_vec, bias_vec) = if self.affine {
            (self.weight.data().to_vec(), self.bias.data().to_vec())
        } else {
            (vec![1.0f32; channels], vec![0.0f32; channels])
        };

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            for g in 0..self.num_groups {
                // Mean over the group's channels and spatial positions.
                let mut sum = 0.0f32;
                let group_size = channels_per_group * spatial_size;

                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                let mean = sum / group_size as f32;

                // Biased variance over the same slice.
                let mut var_sum = 0.0f32;
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let diff = input_vec[idx] - mean;
                        var_sum += diff * diff;
                    }
                }
                let var = var_sum / group_size as f32;

                // Normalize, then apply the per-channel affine transform.
                let std_inv = 1.0 / (var + self.eps).sqrt();
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let normalized = (input_vec[idx] - mean) * std_inv;
                        output_vec[idx] = normalized * weight_vec[channel_idx] + bias_vec[channel_idx];
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad {
            // Without the affine transform there are no parameter gradients,
            // so pass `None` for the weight/bias nodes but still keep the
            // input on the autograd graph.
            let grad_fn = GradFn::new(GroupNormBackward::new(
                input.grad_fn().cloned(),
                if self.affine {
                    self.weight.variable().grad_fn().cloned()
                } else {
                    None
                },
                if self.affine {
                    self.bias.variable().grad_fn().cloned()
                } else {
                    None
                },
                input_data.clone(),
                self.weight.data().clone(),
                self.num_groups,
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "GroupNorm"
    }
}

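/// Instance normalization over a 4D input `[N, C, H, W]`.
///
/// Each (sample, channel) plane is normalized over its spatial positions
/// only. The per-channel affine transform is optional and disabled by
/// default, and no running statistics are kept.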
pub struct InstanceNorm2d {
    /// Learnable per-channel scale (used only when `affine` is true).
    pub weight: Parameter,
    /// Learnable per-channel shift (used only when `affine` is true).
    pub bias: Parameter,
    num_features: usize,
    eps: f32,
    affine: bool,
}

impl InstanceNorm2d {
    /// Creates an `InstanceNorm2d` with the default `eps = 1e-5` and the
    /// affine transform disabled.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, false)
    }

    /// Creates an `InstanceNorm2d` with the affine transform enabled.
    pub fn with_affine(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, true)
    }

    /// Creates an `InstanceNorm2d` with explicit options.
    pub fn with_options(num_features: usize, eps: f32, affine: bool) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), affine),
            bias: Parameter::named("bias", zeros(&[num_features]), affine),
            num_features,
            eps,
            affine,
        }
    }
}

impl Module for InstanceNorm2d {
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();

        assert!(
            shape.len() == 4,
            "InstanceNorm2d expects 4D input (N, C, H, W)"
        );

        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "InstanceNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let input_vec = input_data.to_vec();
        let mut output_vec = vec![0.0f32; input_vec.len()];

        // Hoist the affine parameters out of the loops; without affine the
        // transform is the identity (scale 1, shift 0).
        let (weight_vec, bias_vec) = if self.affine {
            (self.weight.data().to_vec(), self.bias.data().to_vec())
        } else {
            (vec![1.0f32; channels], vec![0.0f32; channels])
        };

        for b in 0..batch_size {
            for c in 0..channels {
                // Mean over this instance's spatial positions.
                let mut sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    sum += input_vec[idx];
                }
                let mean = sum / spatial_size as f32;

                // Biased spatial variance.
                let mut var_sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let diff = input_vec[idx] - mean;
                    var_sum += diff * diff;
                }
                let var = var_sum / spatial_size as f32;

                // Normalize, then apply the per-channel affine transform.
                let std_inv = 1.0 / (var + self.eps).sqrt();
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - mean) * std_inv;
                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).unwrap();
        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(InstanceNorm2dBackward::new(
                input.grad_fn().cloned(),
                if self.affine {
                    self.weight.variable().grad_fn().cloned()
                } else {
                    None
                },
                if self.affine {
                    self.bias.variable().grad_fn().cloned()
                } else {
                    None
                },
                input_data.clone(),
                self.weight.data().clone(),
                self.eps,
                self.affine,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "InstanceNorm2d"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm1d() {
        let bn = BatchNorm1d::new(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }
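
    // A statistics sketch derived from the normalization formula: in training
    // mode with freshly initialized affine parameters (weight = 1, bias = 0),
    // every feature of the output should have mean ~0 across the batch.
    #[test]
    fn test_batchnorm1d_zero_mean() {
        let bn = BatchNorm1d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0], &[3, 2]).unwrap(),
            false,
        );
        let out = bn.forward(&input).data().to_vec();
        for c in 0..2 {
            let mean: f32 = (0..3).map(|b| out[b * 2 + c]).sum::<f32>() / 3.0;
            assert!(mean.abs() < 1e-5, "feature {} has mean {}, expected ~0", c, mean);
        }
    }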

    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }
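
    // Edge-case sketch: with a constant input the per-channel batch variance
    // is zero, so the normalized activations collapse to the bias (zero by
    // default).
    #[test]
    fn test_batchnorm2d_constant_input() {
        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![3.0; 16], &[2, 2, 2, 2]).unwrap(),
            false,
        );
        let out = bn.forward(&input).data().to_vec();
        assert!(out.iter().all(|v| v.abs() < 1e-3));
    }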

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4]).unwrap(),
            false,
        );
        let output = ln.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }
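
    // Per-row statistics sketch: with default weight (1) and bias (0), each
    // normalized row should have mean ~0.
    #[test]
    fn test_layernorm_zero_mean() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0], &[2, 4]).unwrap(),
            false,
        );
        let out = ln.forward(&input).data().to_vec();
        for row in out.chunks(4) {
            let mean: f32 = row.iter().sum::<f32>() / 4.0;
            assert!(mean.abs() < 1e-5);
        }
    }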

    #[test]
    fn test_batchnorm_parameters() {
        let bn = BatchNorm1d::new(10);
        assert_eq!(bn.parameters().len(), 2);
        assert_eq!(bn.num_parameters(), 20);
    }
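
    // Evaluation-mode sketch: a freshly constructed BatchNorm1d has
    // running_mean = 0 and running_var = 1, so in eval mode the output is the
    // input scaled by 1/sqrt(1 + eps), i.e. approximately the identity.
    #[test]
    fn test_batchnorm1d_eval_uses_running_stats() {
        let mut bn = BatchNorm1d::new(2);
        bn.set_training(false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, -2.0, 3.0, -4.0], &[2, 2]).unwrap(),
            false,
        );
        let out = bn.forward(&input).data().to_vec();
        for (o, x) in out.iter().zip([1.0f32, -2.0, 3.0, -4.0].iter()) {
            assert!((o - x).abs() < 1e-3);
        }
    }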

    #[test]
    fn test_groupnorm() {
        let gn = GroupNorm::new(2, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 4, 2, 2]).unwrap(),
            false,
        );
        let output = gn.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 2, 2]);
    }

    #[test]
    fn test_groupnorm_normalization() {
        let gn = GroupNorm::with_options(2, 4, 1e-5, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 4, 1, 2]).unwrap(),
            false,
        );
        let output = gn.forward(&input);
        let out_vec = output.data().to_vec();
        // Each group of four values should be normalized to zero mean.
        let group1_mean: f32 = out_vec[0..4].iter().sum::<f32>() / 4.0;
        let group2_mean: f32 = out_vec[4..8].iter().sum::<f32>() / 4.0;
        assert!(group1_mean.abs() < 1e-5);
        assert!(group2_mean.abs() < 1e-5);
    }
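
    // Complementary variance sketch: each normalized group should also have
    // (biased) variance ~1 when the affine transform is disabled.
    #[test]
    fn test_groupnorm_unit_variance() {
        let gn = GroupNorm::with_options(2, 4, 1e-5, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 4, 1, 2]).unwrap(),
            false,
        );
        let out = gn.forward(&input).data().to_vec();
        for group in out.chunks(4) {
            let mean: f32 = group.iter().sum::<f32>() / 4.0;
            let var: f32 = group.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / 4.0;
            assert!((var - 1.0).abs() < 1e-3);
        }
    }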

    #[test]
    fn test_instancenorm2d() {
        let inn = InstanceNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).unwrap(),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }
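
    // Constant-input sketch: zero spatial variance means every normalized
    // value is exactly zero (no affine transform by default).
    #[test]
    fn test_instancenorm2d_constant_input() {
        let inn = InstanceNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![7.0; 16], &[1, 2, 2, 4]).unwrap(),
            false,
        );
        let out = inn.forward(&input).data().to_vec();
        assert!(out.iter().all(|v| v.abs() < 1e-6));
    }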

    #[test]
    fn test_instancenorm2d_with_affine() {
        let inn = InstanceNorm2d::with_affine(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 64], &[1, 4, 4, 4]).unwrap(),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 4, 4]);
        assert_eq!(inn.parameters().len(), 2);
    }
}