//! Normalization layers: batch, layer, group, and instance normalization.

use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};

use axonml_autograd::Variable;
use axonml_autograd::functions::{
    BatchNorm1dBackward, BatchNorm2dBackward, GroupNormBackward, InstanceNorm2dBackward,
    LayerNormBackward,
};
use axonml_autograd::grad_fn::GradFn;
use axonml_autograd::no_grad::is_grad_enabled;
use axonml_tensor::Tensor;
use parking_lot::RwLock;

use crate::init::{ones, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

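/// Batch Normalization over a 2D (`[N, C]`) or 3D (`[N, C, L]`) input.
///
/// Each feature channel is normalized with per-batch statistics during training
/// and with the tracked running statistics during evaluation, then scaled and
/// shifted by the learnable `weight` and `bias`:
/// `y = (x - mean) / sqrt(var + eps) * weight + bias`.
///
/// A minimal usage sketch (shapes are illustrative):
///
/// ```ignore
/// let bn = BatchNorm1d::new(64);
/// let output = bn.forward(&input); // input: Variable of shape [batch, 64]
/// ```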
pub struct BatchNorm1d {
    /// Learnable per-feature scale (gamma).
    pub weight: Parameter,
    /// Learnable per-feature shift (beta).
    pub bias: Parameter,
    /// Running mean tracked across training batches.
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance tracked across training batches.
    running_var: RwLock<Tensor<f32>>,
    num_features: usize,
    eps: f32,
    momentum: f32,
    track_running_stats: bool,
    training: AtomicBool,
}

impl BatchNorm1d {
    /// Creates a `BatchNorm1d` layer with eps = 1e-5, momentum = 0.1, and
    /// running-statistics tracking enabled.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1, true)
    }

    /// Creates a `BatchNorm1d` layer with explicit options.
    pub fn with_options(
        num_features: usize,
        eps: f32,
        momentum: f32,
        track_running_stats: bool,
    ) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            track_running_stats,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of features this layer normalizes.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm1d {
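    /// Forward pass. In training mode, per-feature mean and variance are
    /// computed from the current batch and, when `track_running_stats` is set,
    /// folded into the running statistics; in evaluation mode the running
    /// statistics are used instead. With the `cuda` feature, a fused GPU kernel
    /// handles training-mode inputs that live on a GPU device.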
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let num_features = shape[1];

        assert_eq!(
            num_features, self.num_features,
            "BatchNorm1d: expected {} features, got {}",
            self.num_features, num_features
        );

        let is_training = self.training.load(Ordering::Relaxed);
        let spatial_size: usize = if shape.len() > 2 {
            shape[2..].iter().product()
        } else {
            1
        };

        // Fused GPU path: one kernel produces the output and the batch statistics.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) = input_data.batchnorm_fused(
                &gamma_gpu,
                &beta_gpu,
                self.eps,
                num_features,
                spatial_size,
            ) {
                if self.track_running_stats {
                    let mut running_mean = self.running_mean.write();
                    let mut running_var = self.running_var.write();
                    let running_mean_vec = running_mean.to_vec();
                    let running_var_vec = running_var.to_vec();
                    let new_mean: Vec<f32> = running_mean_vec
                        .iter()
                        .zip(means.iter())
                        .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                        .collect();
                    let new_var: Vec<f32> = running_var_vec
                        .iter()
                        .zip(vars.iter())
                        .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                        .collect();
                    *running_mean = Tensor::from_vec(new_mean, &[num_features])
                        .expect("tensor creation failed");
                    *running_var = Tensor::from_vec(new_var, &[num_features])
                        .expect("tensor creation failed");
                }

                let weight_vec = gamma_gpu.to_vec();
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU path.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; num_features];
        let mut vars = vec![0.0f32; num_features];

        if is_training {
            // Per-feature mean and (biased) variance over batch and spatial dims.
            for c in 0..num_features {
                let mut sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                means[c] = sum / (batch_size * spatial_size) as f32;

                let mut var_sum = 0.0f32;
                for b in 0..batch_size {
                    for s in 0..spatial_size {
                        let idx = b * num_features * spatial_size + c * spatial_size + s;
                        let diff = input_vec[idx] - means[c];
                        var_sum += diff * diff;
                    }
                }
                vars[c] = var_sum / (batch_size * spatial_size) as f32;
            }

            if self.track_running_stats {
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();

                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();

                *running_mean = Tensor::from_vec(new_mean, &[num_features])
                    .expect("tensor creation failed");
                *running_var = Tensor::from_vec(new_var, &[num_features])
                    .expect("tensor creation failed");
            }
        } else {
            // Evaluation mode: use the tracked running statistics.
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let mut output_vec = vec![0.0f32; input_vec.len()];
        for b in 0..batch_size {
            for c in 0..num_features {
                for s in 0..spatial_size {
                    let idx = b * num_features * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - means[c]) / (vars[c] + self.eps).sqrt();
                    output_vec[idx] = normalized * weight_vec[c] + bias_vec[c];
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm1dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm1d"
    }

    fn to_device(&self, device: axonml_core::Device) {
        for param in self.parameters() {
            param.to_device(device);
        }
        if self.track_running_stats {
            let mut rm = self.running_mean.write();
            if let Ok(moved) = rm.to_device(device) {
                *rm = moved;
            }
            let mut rv = self.running_var.write();
            if let Ok(moved) = rv.to_device(device) {
                *rv = moved;
            }
        }
    }
}

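/// Batch Normalization over a 4D input of shape `[N, C, H, W]`.
///
/// Statistics are computed per channel across the batch and spatial dimensions
/// during training; the tracked running statistics are used in evaluation mode.
///
/// A minimal usage sketch (shapes are illustrative):
///
/// ```ignore
/// let bn = BatchNorm2d::new(16);
/// let output = bn.forward(&input); // input: Variable of shape [batch, 16, h, w]
/// ```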
pub struct BatchNorm2d {
    /// Learnable per-channel scale (gamma).
    pub weight: Parameter,
    /// Learnable per-channel shift (beta).
    pub bias: Parameter,
    /// Running mean tracked across training batches.
    running_mean: RwLock<Tensor<f32>>,
    /// Running variance tracked across training batches.
    running_var: RwLock<Tensor<f32>>,
    num_features: usize,
    eps: f32,
    momentum: f32,
    training: AtomicBool,
}

impl BatchNorm2d {
    /// Creates a `BatchNorm2d` layer with eps = 1e-5 and momentum = 0.1.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, 0.1)
    }

    /// Creates a `BatchNorm2d` layer with explicit eps and momentum.
    pub fn with_options(num_features: usize, eps: f32, momentum: f32) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), true),
            bias: Parameter::named("bias", zeros(&[num_features]), true),
            running_mean: RwLock::new(zeros(&[num_features])),
            running_var: RwLock::new(ones(&[num_features])),
            num_features,
            eps,
            momentum,
            training: AtomicBool::new(true),
        }
    }

    /// Returns the number of channels this layer normalizes.
    pub fn num_features(&self) -> usize {
        self.num_features
    }
}

impl Module for BatchNorm2d {
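    /// Forward pass. In training mode, per-channel mean and variance are
    /// computed over the batch and spatial dimensions and folded into the
    /// running statistics; in evaluation mode the running statistics are used
    /// instead. With the `cuda` feature, a fused GPU kernel handles
    /// training-mode inputs that live on a GPU device.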
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "BatchNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let is_training = self.training.load(Ordering::Relaxed);

        // Fused GPU path: one kernel produces the output and the batch statistics.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() && is_training {
            let gamma_data = self.weight.data();
            let beta_data = self.bias.data();

            let gamma_gpu = if !gamma_data.device().is_gpu() {
                gamma_data
                    .to_device(input_data.device())
                    .unwrap_or(gamma_data)
            } else {
                gamma_data
            };
            let beta_gpu = if !beta_data.device().is_gpu() {
                beta_data
                    .to_device(input_data.device())
                    .unwrap_or(beta_data)
            } else {
                beta_data
            };

            if let Some((output_tensor, means, vars)) =
                input_data.batchnorm_fused(&gamma_gpu, &beta_gpu, self.eps, channels, spatial_size)
            {
                let mut running_mean = self.running_mean.write();
                let mut running_var = self.running_var.write();
                let running_mean_vec = running_mean.to_vec();
                let running_var_vec = running_var.to_vec();
                let new_mean: Vec<f32> = running_mean_vec
                    .iter()
                    .zip(means.iter())
                    .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                    .collect();
                let new_var: Vec<f32> = running_var_vec
                    .iter()
                    .zip(vars.iter())
                    .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                    .collect();
                *running_mean =
                    Tensor::from_vec(new_mean, &[channels]).expect("tensor creation failed");
                *running_var =
                    Tensor::from_vec(new_var, &[channels]).expect("tensor creation failed");

                let weight_vec = gamma_gpu.to_vec();
                let requires_grad =
                    (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
                if requires_grad {
                    let weight_var = self.weight.variable();
                    let bias_var = self.bias.variable();
                    let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                        input.grad_fn().cloned(),
                        weight_var.grad_fn().cloned(),
                        bias_var.grad_fn().cloned(),
                        input_data,
                        means,
                        vars,
                        weight_vec,
                        self.eps,
                        self.num_features,
                    ));
                    return Variable::from_operation(output_tensor, grad_fn, true);
                }
                return Variable::new(output_tensor, false);
            }
        }

        // CPU path.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut means = vec![0.0f32; channels];
        let mut vars = vec![0.0f32; channels];

        if is_training {
            // Per-channel mean and (biased) variance over batch and spatial dims,
            // computed in a single pass via sums and sums of squares.
            let n_per_channel = (batch_size * spatial_size) as f32;
            for c in 0..channels {
                let mut sum = 0.0f32;
                let mut sum_sq = 0.0f32;
                for b in 0..batch_size {
                    let base = b * channels * spatial_size + c * spatial_size;
                    for s in 0..spatial_size {
                        let val = input_vec[base + s];
                        sum += val;
                        sum_sq += val * val;
                    }
                }
                means[c] = sum / n_per_channel;
                vars[c] = sum_sq / n_per_channel - means[c] * means[c];
            }

            let mut running_mean = self.running_mean.write();
            let mut running_var = self.running_var.write();
            let running_mean_vec = running_mean.to_vec();
            let running_var_vec = running_var.to_vec();

            let new_mean: Vec<f32> = running_mean_vec
                .iter()
                .zip(means.iter())
                .map(|(&rm, &m)| (1.0 - self.momentum) * rm + self.momentum * m)
                .collect();
            let new_var: Vec<f32> = running_var_vec
                .iter()
                .zip(vars.iter())
                .map(|(&rv, &v)| (1.0 - self.momentum) * rv + self.momentum * v)
                .collect();

            *running_mean =
                Tensor::from_vec(new_mean, &[channels]).expect("tensor creation failed");
            *running_var = Tensor::from_vec(new_var, &[channels]).expect("tensor creation failed");
        } else {
            // Evaluation mode: use the tracked running statistics.
            means = self.running_mean.read().to_vec();
            vars = self.running_var.read().to_vec();
        }

        let total = input_vec.len();
        let mut output_vec = vec![0.0f32; total];

        let inv_stds: Vec<f32> = vars.iter().map(|v| 1.0 / (v + self.eps).sqrt()).collect();

        for i in 0..total {
            let c = (i / spatial_size) % channels;
            output_vec[i] = (input_vec[i] - means[c]) * inv_stds[c] * weight_vec[c] + bias_vec[c];
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");

        let requires_grad =
            (input.requires_grad() || self.weight.requires_grad()) && is_grad_enabled();
        if requires_grad {
            let weight_var = self.weight.variable();
            let bias_var = self.bias.variable();

            let grad_fn = GradFn::new(BatchNorm2dBackward::new(
                input.grad_fn().cloned(),
                weight_var.grad_fn().cloned(),
                bias_var.grad_fn().cloned(),
                input_data,
                means.clone(),
                vars.clone(),
                weight_vec,
                self.eps,
                self.num_features,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::new(output, false)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training.store(training, Ordering::Relaxed);
    }

    fn is_training(&self) -> bool {
        self.training.load(Ordering::Relaxed)
    }

    fn name(&self) -> &'static str {
        "BatchNorm2d"
    }

    fn to_device(&self, device: axonml_core::Device) {
        for param in self.parameters() {
            param.to_device(device);
        }
        let mut rm = self.running_mean.write();
        if let Ok(moved) = rm.to_device(device) {
            *rm = moved;
        }
        let mut rv = self.running_var.write();
        if let Ok(moved) = rv.to_device(device) {
            *rv = moved;
        }
    }
}

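/// Layer Normalization over the trailing dimensions given by `normalized_shape`.
///
/// Each row of `normalized_shape` elements is normalized with its own mean and
/// variance, then scaled and shifted by the learnable `weight` and `bias`.
/// No running statistics are kept, so behavior is identical in training and
/// evaluation.
///
/// A minimal usage sketch (shapes are illustrative):
///
/// ```ignore
/// let ln = LayerNorm::single(512);
/// let output = ln.forward(&input); // input: Variable of shape [batch, 512]
/// ```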
pub struct LayerNorm {
    /// Learnable per-element scale (gamma).
    pub weight: Parameter,
    /// Learnable per-element shift (beta).
    pub bias: Parameter,
    /// Shape of the trailing dimensions to normalize over.
    normalized_shape: Vec<usize>,
    eps: f32,
}

impl LayerNorm {
    /// Creates a `LayerNorm` over `normalized_shape` with eps = 1e-5.
    pub fn new(normalized_shape: Vec<usize>) -> Self {
        Self::with_eps(normalized_shape, 1e-5)
    }

    /// Convenience constructor for normalizing over a single trailing dimension.
    pub fn single(size: usize) -> Self {
        Self::new(vec![size])
    }

    /// Creates a `LayerNorm` with an explicit eps.
    pub fn with_eps(normalized_shape: Vec<usize>, eps: f32) -> Self {
        let numel: usize = normalized_shape.iter().product();
        Self {
            weight: Parameter::named("weight", ones(&[numel]), true),
            bias: Parameter::named("bias", zeros(&[numel]), true),
            normalized_shape,
            eps,
        }
    }
}

impl Module for LayerNorm {
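    /// Forward pass. Each contiguous row of `normalized_shape` elements is
    /// normalized with its own mean and variance, then scaled by `weight` and
    /// shifted by `bias`. With the `cuda` feature, GPU inputs use a dedicated
    /// layer-norm kernel.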
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let norm_size: usize = self.normalized_shape.iter().product();
        let total_len = input_data.numel();
        let num_rows = total_len / norm_size;

        // CUDA path: dedicated layer-norm kernel.
        #[cfg(feature = "cuda")]
        if input_data.device().is_gpu() {
            let weight_data = self.weight.data();
            let weight_gpu = if weight_data.device().is_gpu() {
                weight_data.clone()
            } else {
                weight_data.to_device(input_data.device().clone()).unwrap()
            };
            let bias_data = self.bias.data();
            let bias_gpu = if bias_data.device().is_gpu() {
                bias_data.clone()
            } else {
                bias_data.to_device(input_data.device().clone()).unwrap()
            };

            let output = input_data
                .layer_norm_cuda(&weight_gpu, &bias_gpu, norm_size, self.eps)
                .expect("CUDA LayerNorm failed");

            let requires_grad = input.requires_grad() && is_grad_enabled();
            return if requires_grad {
                let grad_fn = GradFn::new(LayerNormBackward::new(
                    input.grad_fn().cloned(),
                    self.weight.variable().grad_fn().cloned(),
                    self.bias.variable().grad_fn().cloned(),
                    input_data.clone(),
                    self.weight.data().clone(),
                    self.normalized_shape.clone(),
                    self.eps,
                ));
                Variable::from_operation(output, grad_fn, true)
            } else {
                Variable::from_tensor(output)
            };
        }

        // CPU path: normalize each row of `norm_size` elements independently.
        let input_vec = input_data.to_vec();
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..num_rows {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_vec[start..end];

            let mean: f32 = slice.iter().sum::<f32>() / norm_size as f32;
            let var: f32 =
                slice.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / norm_size as f32;

            for i in 0..norm_size {
                let normalized = (slice[i] - mean) / (var + self.eps).sqrt();
                output_vec[start + i] = normalized * weight_vec[i] + bias_vec[i];
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(LayerNormBackward::new(
                input.grad_fn().cloned(),
                self.weight.variable().grad_fn().cloned(),
                self.bias.variable().grad_fn().cloned(),
                input_data.clone(),
                self.weight.data().clone(),
                self.normalized_shape.clone(),
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight".to_string(), self.weight.clone());
        params.insert("bias".to_string(), self.bias.clone());
        params
    }

    fn name(&self) -> &'static str {
        "LayerNorm"
    }
}

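/// Group Normalization: channels are split into `num_groups` groups and each
/// group is normalized with its own mean and variance, independently per sample.
///
/// `num_channels` must be divisible by `num_groups`. With `affine = true`, a
/// per-channel scale and shift are learned.
///
/// A minimal usage sketch (shapes are illustrative):
///
/// ```ignore
/// let gn = GroupNorm::new(8, 32);
/// let output = gn.forward(&input); // input: Variable of shape [batch, 32, h, w]
/// ```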
pub struct GroupNorm {
    /// Learnable per-channel scale (gamma); used only when `affine` is true.
    pub weight: Parameter,
    /// Learnable per-channel shift (beta); used only when `affine` is true.
    pub bias: Parameter,
    num_groups: usize,
    num_channels: usize,
    eps: f32,
    affine: bool,
}

impl GroupNorm {
    /// Creates a `GroupNorm` with eps = 1e-5 and affine parameters enabled.
    pub fn new(num_groups: usize, num_channels: usize) -> Self {
        Self::with_options(num_groups, num_channels, 1e-5, true)
    }

    /// Creates a `GroupNorm` with explicit options.
    ///
    /// Panics if `num_channels` is not divisible by `num_groups`.
    pub fn with_options(num_groups: usize, num_channels: usize, eps: f32, affine: bool) -> Self {
        assert!(
            num_channels % num_groups == 0,
            "num_channels ({}) must be divisible by num_groups ({})",
            num_channels,
            num_groups
        );

        Self {
            weight: Parameter::named("weight", ones(&[num_channels]), affine),
            bias: Parameter::named("bias", zeros(&[num_channels]), affine),
            num_groups,
            num_channels,
            eps,
            affine,
        }
    }
}

impl Module for GroupNorm {
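    /// Forward pass. For every sample, each group of channels is normalized
    /// with its own mean and variance computed over the group's channels and
    /// spatial positions; the per-channel affine transform is applied when
    /// `affine` is enabled.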
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();
        let batch_size = shape[0];
        let channels = shape[1];
        let spatial_size: usize = shape[2..].iter().product();

        assert_eq!(
            channels, self.num_channels,
            "GroupNorm: expected {} channels, got {}",
            self.num_channels, channels
        );

        let input_vec = input_data.to_vec();
        let channels_per_group = channels / self.num_groups;

        // Read the affine parameters once instead of converting the tensors
        // inside the per-channel loop.
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        let mut output_vec = vec![0.0f32; input_vec.len()];

        for b in 0..batch_size {
            for g in 0..self.num_groups {
                // Mean over all elements of this (sample, group).
                let mut sum = 0.0f32;
                let group_size = channels_per_group * spatial_size;

                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        sum += input_vec[idx];
                    }
                }
                let mean = sum / group_size as f32;

                // Biased variance over the same elements.
                let mut var_sum = 0.0f32;
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let diff = input_vec[idx] - mean;
                        var_sum += diff * diff;
                    }
                }
                let var = var_sum / group_size as f32;

                let std_inv = 1.0 / (var + self.eps).sqrt();
                for c in 0..channels_per_group {
                    let channel_idx = g * channels_per_group + c;
                    let weight = if self.affine {
                        weight_vec[channel_idx]
                    } else {
                        1.0
                    };
                    let bias = if self.affine {
                        bias_vec[channel_idx]
                    } else {
                        0.0
                    };

                    for s in 0..spatial_size {
                        let idx = b * channels * spatial_size + channel_idx * spatial_size + s;
                        let normalized = (input_vec[idx] - mean) * std_inv;
                        output_vec[idx] = normalized * weight + bias;
                    }
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad && self.affine {
            let grad_fn = GradFn::new(GroupNormBackward::new(
                input.grad_fn().cloned(),
                self.weight.variable().grad_fn().cloned(),
                self.bias.variable().grad_fn().cloned(),
                input_data.clone(),
                self.weight.data().clone(),
                self.num_groups,
                self.eps,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "GroupNorm"
    }
}

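/// Instance Normalization over a 4D input of shape `[N, C, H, W]`.
///
/// Each (sample, channel) pair is normalized over its spatial dimensions.
/// Affine parameters are disabled by default; use [`InstanceNorm2d::with_affine`]
/// to enable a learnable per-channel scale and shift.
///
/// A minimal usage sketch (shapes are illustrative):
///
/// ```ignore
/// let inorm = InstanceNorm2d::new(3);
/// let output = inorm.forward(&input); // input: Variable of shape [batch, 3, h, w]
/// ```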
pub struct InstanceNorm2d {
    /// Learnable per-channel scale (gamma); used only when `affine` is true.
    pub weight: Parameter,
    /// Learnable per-channel shift (beta); used only when `affine` is true.
    pub bias: Parameter,
    num_features: usize,
    eps: f32,
    affine: bool,
}

impl InstanceNorm2d {
    /// Creates an `InstanceNorm2d` layer with eps = 1e-5 and no affine parameters.
    pub fn new(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, false)
    }

    /// Creates an `InstanceNorm2d` layer with learnable affine parameters.
    pub fn with_affine(num_features: usize) -> Self {
        Self::with_options(num_features, 1e-5, true)
    }

    /// Creates an `InstanceNorm2d` layer with explicit options.
    pub fn with_options(num_features: usize, eps: f32, affine: bool) -> Self {
        Self {
            weight: Parameter::named("weight", ones(&[num_features]), affine),
            bias: Parameter::named("bias", zeros(&[num_features]), affine),
            num_features,
            eps,
            affine,
        }
    }
}

impl Module for InstanceNorm2d {
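    /// Forward pass. Each (sample, channel) slice is normalized over its
    /// spatial positions with its own mean and variance; the per-channel
    /// affine transform is applied when `affine` is enabled.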
    fn forward(&self, input: &Variable) -> Variable {
        let input_data = input.data();
        let shape = input_data.shape().to_vec();

        assert!(
            shape.len() == 4,
            "InstanceNorm2d expects 4D input (N, C, H, W)"
        );

        let batch_size = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial_size = height * width;

        assert_eq!(
            channels, self.num_features,
            "InstanceNorm2d: expected {} channels, got {}",
            self.num_features, channels
        );

        let input_vec = input_data.to_vec();
        let mut output_vec = vec![0.0f32; input_vec.len()];

        // Read the affine parameters once instead of converting the tensors
        // inside the per-channel loop.
        let weight_vec = self.weight.data().to_vec();
        let bias_vec = self.bias.data().to_vec();

        for b in 0..batch_size {
            for c in 0..channels {
                // Mean over the spatial positions of this (sample, channel).
                let mut sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    sum += input_vec[idx];
                }
                let mean = sum / spatial_size as f32;

                // Biased variance over the same elements.
                let mut var_sum = 0.0f32;
                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let diff = input_vec[idx] - mean;
                    var_sum += diff * diff;
                }
                let var = var_sum / spatial_size as f32;

                let std_inv = 1.0 / (var + self.eps).sqrt();
                let weight = if self.affine { weight_vec[c] } else { 1.0 };
                let bias = if self.affine { bias_vec[c] } else { 0.0 };

                for s in 0..spatial_size {
                    let idx = b * channels * spatial_size + c * spatial_size + s;
                    let normalized = (input_vec[idx] - mean) * std_inv;
                    output_vec[idx] = normalized * weight + bias;
                }
            }
        }

        let output = Tensor::from_vec(output_vec, &shape).expect("tensor creation failed");
        let requires_grad = input.requires_grad() && is_grad_enabled();

        if requires_grad {
            let grad_fn = GradFn::new(InstanceNorm2dBackward::new(
                input.grad_fn().cloned(),
                if self.affine {
                    self.weight.variable().grad_fn().cloned()
                } else {
                    None
                },
                if self.affine {
                    self.bias.variable().grad_fn().cloned()
                } else {
                    None
                },
                input_data.clone(),
                self.weight.data().clone(),
                self.eps,
                self.affine,
            ));
            Variable::from_operation(output, grad_fn, true)
        } else {
            Variable::from_tensor(output)
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        if self.affine {
            vec![self.weight.clone(), self.bias.clone()]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        if self.affine {
            let mut params = HashMap::new();
            params.insert("weight".to_string(), self.weight.clone());
            params.insert("bias".to_string(), self.bias.clone());
            params
        } else {
            HashMap::new()
        }
    }

    fn name(&self) -> &'static str {
        "InstanceNorm2d"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batchnorm1d() {
        let bn = BatchNorm1d::new(3);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3])
                .expect("tensor creation failed"),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 3]);
    }

    #[test]
    fn test_batchnorm2d() {
        let bn = BatchNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).expect("tensor creation failed"),
            false,
        );
        let output = bn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_layernorm() {
        let ln = LayerNorm::single(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
                .expect("tensor creation failed"),
            false,
        );
        let output = ln.forward(&input);
        assert_eq!(output.shape(), vec![2, 4]);
    }

    #[test]
    fn test_batchnorm_parameters() {
        let bn = BatchNorm1d::new(10);
        assert_eq!(bn.parameters().len(), 2);
        assert_eq!(bn.num_parameters(), 20);
    }

    #[test]
    fn test_groupnorm() {
        let gn = GroupNorm::new(2, 4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 4, 2, 2]).expect("tensor creation failed"),
            false,
        );
        let output = gn.forward(&input);
        assert_eq!(output.shape(), vec![2, 4, 2, 2]);
    }

    #[test]
    fn test_groupnorm_normalization() {
        let gn = GroupNorm::with_options(2, 4, 1e-5, false);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 4, 1, 2])
                .expect("tensor creation failed"),
            false,
        );
        let output = gn.forward(&input);
        let out_vec = output.data().to_vec();
        let group1_mean: f32 = out_vec[0..4].iter().sum::<f32>() / 4.0;
        let group2_mean: f32 = out_vec[4..8].iter().sum::<f32>() / 4.0;
        assert!(group1_mean.abs() < 1e-5);
        assert!(group2_mean.abs() < 1e-5);
    }

    #[test]
    fn test_instancenorm2d() {
        let inn = InstanceNorm2d::new(2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 32], &[2, 2, 2, 4]).expect("tensor creation failed"),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![2, 2, 2, 4]);
    }

    #[test]
    fn test_instancenorm2d_with_affine() {
        let inn = InstanceNorm2d::with_affine(4);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 64], &[1, 4, 4, 4]).expect("tensor creation failed"),
            false,
        );
        let output = inn.forward(&input);
        assert_eq!(output.shape(), vec![1, 4, 4, 4]);
        assert_eq!(inn.parameters().len(), 2);
    }
}