1use std::any::TypeId;
26use std::sync::{Arc, Mutex};
27
28use ferrotorch_core::autograd::no_grad::is_grad_enabled;
29use ferrotorch_core::gpu_dispatch::gpu_backend;
30use ferrotorch_core::tensor::GradFn;
31use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
32
33use crate::module::Module;
34use crate::parameter::Parameter;
35
/// Returns `true` when `T` is exactly `f32`.
///
/// Used to gate code paths whose GPU kernels exist only in an `f32` flavour.
/// The bound is just `?Sized + 'static` (the requirement of [`TypeId::of`]),
/// so the check is not tied to the `Float` trait; every `T: Float` caller
/// already satisfies it because the previous body required `T: 'static` too.
#[inline]
fn is_f32<T: ?Sized + 'static>() -> bool {
    TypeId::of::<T>() == TypeId::of::<f32>()
}
40
/// Returns `true` when `T` is exactly `f64`.
///
/// Used to pick the `f64` GPU kernels where one exists. The bound is just
/// `?Sized + 'static` (the requirement of [`TypeId::of`]); every `T: Float`
/// caller already satisfies it.
#[inline]
fn is_f64<T: ?Sized + 'static>() -> bool {
    TypeId::of::<T>() == TypeId::of::<f64>()
}
45
46#[inline]
48fn zero<T: Float>() -> T {
49 <T as num_traits::Zero>::zero()
50}
51
52fn gpu_handle_to_f32(
58 backend: &dyn ferrotorch_core::gpu_dispatch::GpuBackend,
59 handle: &ferrotorch_core::gpu_dispatch::GpuBufferHandle,
60) -> FerrotorchResult<Vec<f32>> {
61 let bytes = backend.gpu_to_cpu(handle)?;
62 Ok(bytes
63 .chunks_exact(4)
64 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
65 .collect())
66}
67
68#[allow(clippy::too_many_arguments)]
85fn batch_norm_gpu_forward<T: Float>(
86 input: &Tensor<T>,
87 weight: Option<&Tensor<T>>,
88 bias: Option<&Tensor<T>>,
89 running_mean: &Mutex<Vec<f64>>,
90 running_var: &Mutex<Vec<f64>>,
91 num_batches_tracked: &Mutex<usize>,
92 momentum: f64,
93 eps: f64,
94 channels: usize,
95 spatial: usize,
96 is_training: bool,
97) -> FerrotorchResult<Option<ferrotorch_core::gpu_dispatch::GpuBufferHandle>> {
98 let Some(backend) = gpu_backend() else {
99 return Ok(None);
100 };
101 let batch = input.numel() / (channels * spatial.max(1));
102
103 let weight_dev;
107 let bias_dev;
108 let (w_handle, b_handle) = match (weight, bias) {
109 (Some(w), Some(b)) => (w.gpu_handle()?, b.gpu_handle()?),
110 _ => {
111 weight_dev = ferrotorch_core::creation::ones::<T>(&[channels])?.to(input.device())?;
112 bias_dev = ferrotorch_core::creation::zeros::<T>(&[channels])?.to(input.device())?;
113 (weight_dev.gpu_handle()?, bias_dev.gpu_handle()?)
114 }
115 };
116
117 let rm_snapshot: Vec<f32> = running_mean
121 .lock()
122 .unwrap()
123 .iter()
124 .map(|v| *v as f32)
125 .collect();
126 let rv_snapshot: Vec<f32> = running_var
127 .lock()
128 .unwrap()
129 .iter()
130 .map(|v| *v as f32)
131 .collect();
132 let mean_in = ferrotorch_core::creation::from_slice::<f32>(&rm_snapshot, &[channels])?
133 .to(input.device())?;
134 let var_in = ferrotorch_core::creation::from_slice::<f32>(&rv_snapshot, &[channels])?
135 .to(input.device())?;
136
137 let (out_handle, mean_out, var_out) = backend.batch_norm_f32(
138 input.gpu_handle()?,
139 w_handle,
140 b_handle,
141 mean_in.gpu_handle()?,
142 var_in.gpu_handle()?,
143 batch,
144 channels,
145 spatial.max(1),
146 eps as f32,
147 is_training,
148 )?;
149
150 if is_training {
151 let batch_mean = gpu_handle_to_f32(backend, &mean_out)?;
156 let batch_var = gpu_handle_to_f32(backend, &var_out)?;
157 let count = batch * spatial.max(1);
158 let bessel = if count > 1 {
159 count as f64 / (count as f64 - 1.0)
160 } else {
161 1.0
162 };
163 let mut rm = running_mean.lock().unwrap();
164 let mut rv = running_var.lock().unwrap();
165 let mut nbt = num_batches_tracked.lock().unwrap();
166 *nbt += 1;
167 for c in 0..channels {
168 let bm = batch_mean[c] as f64;
169 let bv = batch_var[c] as f64;
170 rm[c] = (1.0 - momentum) * rm[c] + momentum * bm;
171 rv[c] = (1.0 - momentum) * rv[c] + momentum * bv * bessel;
172 }
173 }
174
175 Ok(Some(out_handle))
176}
177
178#[allow(clippy::too_many_arguments)]
198#[allow(
199 clippy::fn_params_excessive_bools,
200 reason = "the four flags (training / affine / want_weight_grad / want_bias_grad) each \
201 control a distinct branch of the BatchNorm backward contract and mirror \
202 PyTorch's train arg + grad_input_mask; no meaningful struct grouping exists"
203)]
204fn batch_norm_gpu_backward<T: Float>(
205 input: &Tensor<T>,
206 grad_output: &Tensor<T>,
207 weight_buf: &Tensor<T>,
208 mean: &[f64],
209 var: &[f64],
210 batch: usize,
211 channels: usize,
212 spatial: usize,
213 eps: f64,
214 training: bool,
215 affine: bool,
216 want_weight_grad: bool,
217 want_bias_grad: bool,
218) -> FerrotorchResult<Option<Vec<Option<Tensor<T>>>>> {
219 let Some(backend) = gpu_backend() else {
220 return Ok(None);
221 };
222 let mean_f32: Vec<f32> = mean.iter().map(|v| *v as f32).collect();
223 let var_f32: Vec<f32> = var.iter().map(|v| *v as f32).collect();
224 let mean_dev =
225 ferrotorch_core::creation::from_slice::<f32>(&mean_f32, &[channels])?.to(input.device())?;
226 let var_dev =
227 ferrotorch_core::creation::from_slice::<f32>(&var_f32, &[channels])?.to(input.device())?;
228
229 let (gi_h, gw_h, gb_h) = backend.batch_norm_backward_f32(
230 input.gpu_handle()?,
231 grad_output.gpu_handle()?,
232 weight_buf.gpu_handle()?,
233 mean_dev.gpu_handle()?,
234 var_dev.gpu_handle()?,
235 batch,
236 channels,
237 spatial.max(1),
238 eps as f32,
239 training,
240 )?;
241
242 let grad_input = Tensor::from_storage(TensorStorage::gpu(gi_h), input.shape().to_vec(), false)?;
243 let grad_weight = if affine && want_weight_grad {
244 Some(Tensor::from_storage(
245 TensorStorage::gpu(gw_h),
246 vec![channels],
247 false,
248 )?)
249 } else {
250 None
251 };
252 let grad_bias = if affine && want_bias_grad {
253 Some(Tensor::from_storage(
254 TensorStorage::gpu(gb_h),
255 vec![channels],
256 false,
257 )?)
258 } else {
259 None
260 };
261
262 if affine {
271 Ok(Some(vec![Some(grad_input), grad_weight, grad_bias]))
272 } else {
273 Ok(Some(vec![Some(grad_input)]))
274 }
275}
276
277#[allow(clippy::too_many_arguments)]
291#[allow(
292 clippy::fn_params_excessive_bools,
293 reason = "affine / want_weight_grad / want_bias_grad each gate a distinct branch of the \
294 InstanceNorm backward contract and mirror PyTorch's grad_input_mask; no \
295 meaningful struct grouping exists"
296)]
297fn instance_norm_gpu_backward<T: Float>(
298 input: &Tensor<T>,
299 grad_output: &Tensor<T>,
300 weight: &Tensor<T>,
301 batch: usize,
302 channels: usize,
303 spatial: usize,
304 eps: f64,
305 affine: bool,
306 want_weight_grad: bool,
307 want_bias_grad: bool,
308) -> FerrotorchResult<Option<Vec<Option<Tensor<T>>>>> {
309 let Some(backend) = gpu_backend() else {
310 return Ok(None);
311 };
312 let bc = batch * channels;
313
314 let weight_host: Vec<f32> = if affine {
317 let w = weight.data_vec()?;
321 let mut tiled = vec![0.0f32; bc];
322 for b in 0..batch {
323 for c in 0..channels {
324 tiled[b * channels + c] = w[c].to_f32().unwrap();
325 }
326 }
327 tiled
328 } else {
329 vec![1.0f32; bc]
330 };
331 let weight_dev =
332 ferrotorch_core::creation::from_slice::<f32>(&weight_host, &[bc])?.to(input.device())?;
333 let stat_dummy = ferrotorch_core::creation::zeros::<f32>(&[bc])?.to(input.device())?;
336
337 let (gi_h, gw_h, gb_h) = backend.batch_norm_backward_f32(
338 input.gpu_handle()?,
339 grad_output.gpu_handle()?,
340 weight_dev.gpu_handle()?,
341 stat_dummy.gpu_handle()?,
342 stat_dummy.gpu_handle()?,
343 1, bc, spatial.max(1),
346 eps as f32,
347 true, )?;
349
350 let grad_input = Tensor::from_storage(TensorStorage::gpu(gi_h), input.shape().to_vec(), false)?;
351
352 let reduce =
355 |handle: ferrotorch_core::gpu_dispatch::GpuBufferHandle| -> FerrotorchResult<Tensor<T>> {
356 let summed = backend.sum_axis_f32(&handle, &[batch, channels], 0)?;
357 Tensor::from_storage(TensorStorage::gpu(summed), vec![channels], false)
358 };
359 let grad_weight = if affine && want_weight_grad {
360 Some(reduce(gw_h)?)
361 } else {
362 None
363 };
364 let grad_bias = if affine && want_bias_grad {
365 Some(reduce(gb_h)?)
366 } else {
367 None
368 };
369
370 Ok(Some(vec![Some(grad_input), grad_weight, grad_bias]))
371}
372
/// Layer normalization over the trailing `normalized_shape` dimensions of the input.
///
/// Build it with [`LayerNorm::new`], which initializes `weight` to ones and
/// `bias` to zeros, both shaped like `normalized_shape`.
#[derive(Debug)]
pub struct LayerNorm<T: Float> {
    /// Trailing input dims to normalize over. `forward` requires the input's
    /// last dims to match this exactly.
    pub normalized_shape: Vec<usize>,
    /// Added to the variance before taking the square root, for numerical stability.
    pub eps: f64,
    /// When `true`, the normalized output is scaled by `weight` and shifted by
    /// `bias`, and both are reported as parameters.
    pub elementwise_affine: bool,
    /// Per-element scale, shaped like `normalized_shape`.
    pub weight: Parameter<T>,
    /// Per-element shift, shaped like `normalized_shape`.
    pub bias: Parameter<T>,
    /// Train/eval flag set by `train()` / `eval()`. LayerNorm's `forward` does not read it.
    training: bool,
}
404
405impl<T: Float> LayerNorm<T> {
406 pub fn new(
416 normalized_shape: Vec<usize>,
417 eps: f64,
418 elementwise_affine: bool,
419 ) -> FerrotorchResult<Self> {
420 if normalized_shape.is_empty() {
421 return Err(FerrotorchError::InvalidArgument {
422 message: "normalized_shape must not be empty".into(),
423 });
424 }
425
426 let weight = Parameter::ones(&normalized_shape)?;
427 let bias = Parameter::zeros(&normalized_shape)?;
428
429 Ok(Self {
430 normalized_shape,
431 eps,
432 elementwise_affine,
433 weight,
434 bias,
435 training: true,
436 })
437 }
438}
439
440impl<T: Float> Module<T> for LayerNorm<T> {
441 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
442 let shape = input.shape().to_vec();
443 let ndim = shape.len();
444 let norm_ndim = self.normalized_shape.len();
445
446 if ndim < norm_ndim {
447 return Err(FerrotorchError::ShapeMismatch {
448 message: format!(
449 "LayerNorm: input has {} dims but normalized_shape has {} dims",
450 ndim, norm_ndim
451 ),
452 });
453 }
454
455 let last_dims = &shape[ndim - norm_ndim..];
457 if last_dims != self.normalized_shape.as_slice() {
458 return Err(FerrotorchError::ShapeMismatch {
459 message: format!(
460 "LayerNorm: input last dims {:?} don't match normalized_shape {:?}",
461 last_dims, self.normalized_shape
462 ),
463 });
464 }
465
466 let norm_size: usize = self.normalized_shape.iter().product();
467 let batch_size = input.numel() / norm_size;
468
469 if input.is_cuda() && self.elementwise_affine {
471 if let Some(backend) = ferrotorch_core::gpu_dispatch::gpu_backend() {
472 let eps_f32 = self.eps as f32;
473 let handle = backend.layernorm_f32(
474 input.gpu_handle()?,
475 self.weight.tensor().gpu_handle()?,
476 self.bias.tensor().gpu_handle()?,
477 batch_size,
478 norm_size,
479 eps_f32,
480 )?;
481 return if is_grad_enabled() && input.requires_grad() {
482 let grad_fn = Arc::new(LayerNormBackward {
483 input: input.clone(),
484 weight: self.weight.tensor().clone(),
485 bias: self.bias.tensor().clone(),
486 normalized_shape: self.normalized_shape.clone(),
487 eps: self.eps,
488 elementwise_affine: self.elementwise_affine,
489 });
490 Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
491 } else {
492 Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
493 };
494 }
495 }
496
497 if input.is_cuda() {
499 return Err(FerrotorchError::NotImplementedOnCuda {
500 op: "LayerNorm::forward",
501 });
502 }
503 let input_data = input.data()?;
504 let eps_t = T::from(self.eps).unwrap();
505 let n_t = T::from(norm_size).unwrap();
506
507 let weight_data = self.weight.tensor().data()?;
508 let bias_data = self.bias.tensor().data()?;
509
510 let mut output = Vec::with_capacity(input.numel());
511
512 for b in 0..batch_size {
513 let start = b * norm_size;
514 let end = start + norm_size;
515 let slice = &input_data[start..end];
516
517 let mean = slice.iter().copied().fold(zero::<T>(), |a, x| a + x) / n_t;
518 let var = slice.iter().copied().fold(zero::<T>(), |a, x| {
519 let d = x - mean;
520 a + d * d
521 }) / n_t;
522 let inv_std = (var + eps_t).sqrt().recip();
523
524 for (i, &x) in slice.iter().enumerate() {
525 let normed = (x - mean) * inv_std;
526 if self.elementwise_affine {
527 output.push(normed * weight_data[i] + bias_data[i]);
528 } else {
529 output.push(normed);
530 }
531 }
532 }
533
534 let storage = TensorStorage::cpu(output);
535
536 if is_grad_enabled() && input.requires_grad() {
537 let grad_fn = Arc::new(LayerNormBackward {
538 input: input.clone(),
539 weight: self.weight.tensor().clone(),
540 bias: self.bias.tensor().clone(),
541 normalized_shape: self.normalized_shape.clone(),
542 eps: self.eps,
543 elementwise_affine: self.elementwise_affine,
544 });
545 Tensor::from_operation(storage, shape.to_vec(), grad_fn)
546 } else {
547 Tensor::from_storage(storage, shape.to_vec(), false)
548 }
549 }
550
551 fn parameters(&self) -> Vec<&Parameter<T>> {
552 if self.elementwise_affine {
553 vec![&self.weight, &self.bias]
554 } else {
555 vec![]
556 }
557 }
558
559 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
560 if self.elementwise_affine {
561 vec![&mut self.weight, &mut self.bias]
562 } else {
563 vec![]
564 }
565 }
566
567 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
568 if self.elementwise_affine {
569 vec![
570 ("weight".to_string(), &self.weight),
571 ("bias".to_string(), &self.bias),
572 ]
573 } else {
574 vec![]
575 }
576 }
577
578 fn train(&mut self) {
579 self.training = true;
580 }
581
582 fn eval(&mut self) {
583 self.training = false;
584 }
585
586 fn is_training(&self) -> bool {
587 self.training
588 }
589}
590
/// Autograd node recorded by `LayerNorm::forward`.
///
/// It saves the forward input and the affine tensors. The backward pass
/// recomputes the per-slice mean and inverse std from `input` instead of
/// storing them. Gradients are returned in the order of `inputs()`:
/// `[input, weight, bias]`.
#[derive(Debug)]
struct LayerNormBackward<T: Float> {
    /// The forward input, needed to recompute the normalization statistics.
    input: Tensor<T>,
    /// Affine scale used in the forward pass.
    weight: Tensor<T>,
    /// Affine shift used in the forward pass. Only its `requires_grad` flag is
    /// consulted during backward.
    bias: Tensor<T>,
    /// Trailing dims that were normalized. Their product is the slice length.
    normalized_shape: Vec<usize>,
    /// Same stabilizer the forward pass added to the variance.
    eps: f64,
    /// Whether the forward pass applied `weight` / `bias`.
    elementwise_affine: bool,
}
615
impl<T: Float> GradFn<T> for LayerNormBackward<T> {
    /// Backward pass of LayerNorm.
    ///
    /// Returns gradients in the order of `inputs()`: `[input, weight, bias]`.
    /// The weight/bias entries are `None` when the tensor does not require
    /// grad. On the CPU path they are also `None` when `elementwise_affine` is
    /// off.
    ///
    /// CUDA inputs use the fused `layernorm_backward_f64` /
    /// `layernorm_backward_f32` kernels, chosen by `T`. This only happens when
    /// `T` is `f32` or `f64` and `elementwise_affine` is set. Every other CUDA
    /// case returns `NotImplementedOnCuda`, including when no backend is
    /// registered.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // Length of one normalized slice, and the number of such slices.
        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = self.input.numel() / norm_size;

        if self.input.is_cuda() && (is_f32::<T>() || is_f64::<T>()) && self.elementwise_affine {
            if let Some(backend) = gpu_backend() {
                // Dispatch on the element type. The f64 kernel receives eps at
                // full precision, the f32 kernel a narrowed copy.
                let (gi_h, gw_h, gb_h) = if is_f64::<T>() {
                    backend.layernorm_backward_f64(
                        self.input.gpu_handle()?,
                        grad_output.gpu_handle()?,
                        self.weight.gpu_handle()?,
                        batch_size,
                        norm_size,
                        self.eps,
                    )?
                } else {
                    backend.layernorm_backward_f32(
                        self.input.gpu_handle()?,
                        grad_output.gpu_handle()?,
                        self.weight.gpu_handle()?,
                        batch_size,
                        norm_size,
                        self.eps as f32,
                    )?
                };

                let grad_input_tensor = Tensor::from_storage(
                    TensorStorage::gpu(gi_h),
                    self.input.shape().to_vec(),
                    false,
                )?;

                let grad_weight_out = if self.weight.requires_grad() {
                    Some(Tensor::from_storage(
                        TensorStorage::gpu(gw_h),
                        self.normalized_shape.clone(),
                        false,
                    )?)
                } else {
                    None
                };

                let grad_bias_out = if self.bias.requires_grad() {
                    Some(Tensor::from_storage(
                        TensorStorage::gpu(gb_h),
                        self.normalized_shape.clone(),
                        false,
                    )?)
                } else {
                    None
                };

                return Ok(vec![
                    Some(grad_input_tensor),
                    grad_weight_out,
                    grad_bias_out,
                ]);
            }
        }

        // Any CUDA case not handled above has no implementation.
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "LayerNormBackward",
            });
        }
        let n_t = T::from(norm_size).unwrap();
        let eps_t = T::from(self.eps).unwrap();

        let input_data = self.input.data()?;
        let go_data = grad_output.data()?;
        let weight_data = self.weight.data()?;

        let mut grad_input = vec![zero::<T>(); self.input.numel()];
        // Weight/bias gradients accumulate across all slices.
        let mut grad_weight = vec![zero::<T>(); norm_size];
        let mut grad_bias = vec![zero::<T>(); norm_size];

        for b in 0..batch_size {
            let start = b * norm_size;
            let end = start + norm_size;
            let x_slice = &input_data[start..end];
            let go_slice = &go_data[start..end];

            // Recompute the forward statistics (biased variance) for this slice.
            let mean = x_slice.iter().copied().fold(zero::<T>(), |a, x| a + x) / n_t;
            let var = x_slice.iter().copied().fold(zero::<T>(), |a, x| {
                let d = x - mean;
                a + d * d
            }) / n_t;
            let inv_std = (var + eps_t).sqrt().recip();

            // First pass: sums of dL/dx_hat and dL/dx_hat * x_hat over the slice.
            let mut dl_dx_hat_sum = zero::<T>();
            let mut dl_dx_hat_x_hat_sum = zero::<T>();

            for i in 0..norm_size {
                let x_hat_i = (x_slice[i] - mean) * inv_std;
                // dL/dx_hat = grad_output * weight (or grad_output without affine).
                let dl_dx_hat_i = if self.elementwise_affine {
                    go_slice[i] * weight_data[i]
                } else {
                    go_slice[i]
                };

                dl_dx_hat_sum += dl_dx_hat_i;
                dl_dx_hat_x_hat_sum += dl_dx_hat_i * x_hat_i;

                if self.elementwise_affine {
                    grad_weight[i] += go_slice[i] * x_hat_i;
                    grad_bias[i] += go_slice[i];
                }
            }

            let dl_dx_hat_mean = dl_dx_hat_sum / n_t;
            let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / n_t;

            // Second pass:
            // dx = inv_std * (dL/dx_hat - mean(dL/dx_hat) - x_hat * mean(dL/dx_hat * x_hat)).
            for i in 0..norm_size {
                let x_hat_i = (x_slice[i] - mean) * inv_std;
                let dl_dx_hat_i = if self.elementwise_affine {
                    go_slice[i] * weight_data[i]
                } else {
                    go_slice[i]
                };

                grad_input[start + i] =
                    inv_std * (dl_dx_hat_i - dl_dx_hat_mean - x_hat_i * dl_dx_hat_x_hat_mean);
            }
        }

        let grad_input_tensor = Tensor::from_storage(
            TensorStorage::cpu(grad_input),
            self.input.shape().to_vec(),
            false,
        )?;

        let grad_weight_out = if self.elementwise_affine && self.weight.requires_grad() {
            Some(Tensor::from_storage(
                TensorStorage::cpu(grad_weight),
                self.normalized_shape.clone(),
                false,
            )?)
        } else {
            None
        };

        let grad_bias_out = if self.elementwise_affine && self.bias.requires_grad() {
            Some(Tensor::from_storage(
                TensorStorage::cpu(grad_bias),
                self.normalized_shape.clone(),
                false,
            )?)
        } else {
            None
        };

        Ok(vec![
            Some(grad_input_tensor),
            grad_weight_out,
            grad_bias_out,
        ])
    }

    /// Tensors whose gradients `backward` returns, in the same order.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input, &self.weight, &self.bias]
    }

    fn name(&self) -> &'static str {
        "LayerNormBackward"
    }
}
789
/// Group normalization over inputs shaped `[B, C, ...]`.
///
/// The `C` channels are split into `num_groups` contiguous groups, and each
/// (sample, group) pair is normalized on its own. Build it with
/// [`GroupNorm::new`], which initializes `weight` to ones and `bias` to zeros,
/// both of shape `[num_channels]`.
#[derive(Debug)]
pub struct GroupNorm<T: Float> {
    /// Number of channel groups. `num_channels` must be divisible by it.
    pub num_groups: usize,
    /// Expected size of the input's channel dim (`shape[1]`).
    pub num_channels: usize,
    /// Added to the variance before taking the square root, for numerical stability.
    pub eps: f64,
    /// When `true`, the CPU path applies the per-channel `weight` / `bias`, and
    /// both are reported as parameters.
    pub affine: bool,
    /// Per-channel scale, shape `[num_channels]`.
    pub weight: Parameter<T>,
    /// Per-channel shift, shape `[num_channels]`.
    pub bias: Parameter<T>,
    /// Train/eval flag set by `train()` / `eval()`. GroupNorm's `forward` does not read it.
    training: bool,
}
818
819impl<T: Float> GroupNorm<T> {
820 pub fn new(
829 num_groups: usize,
830 num_channels: usize,
831 eps: f64,
832 affine: bool,
833 ) -> FerrotorchResult<Self> {
834 if num_groups == 0 {
835 return Err(FerrotorchError::InvalidArgument {
836 message: "num_groups must be positive".into(),
837 });
838 }
839 if num_channels == 0 {
840 return Err(FerrotorchError::InvalidArgument {
841 message: "num_channels must be positive".into(),
842 });
843 }
844 if num_channels % num_groups != 0 {
845 return Err(FerrotorchError::InvalidArgument {
846 message: format!(
847 "num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
848 ),
849 });
850 }
851
852 let weight = Parameter::ones(&[num_channels])?;
853 let bias = Parameter::zeros(&[num_channels])?;
854
855 Ok(Self {
856 num_groups,
857 num_channels,
858 eps,
859 affine,
860 weight,
861 bias,
862 training: true,
863 })
864 }
865}
866
867impl<T: Float> Module<T> for GroupNorm<T> {
868 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
869 let shape = input.shape().to_vec();
870 if shape.len() < 2 {
871 return Err(FerrotorchError::ShapeMismatch {
872 message: format!(
873 "GroupNorm: input must have at least 2 dims [B, C, ...], got {:?}",
874 shape
875 ),
876 });
877 }
878
879 let batch_size = shape[0];
880 let channels = shape[1];
881
882 if channels != self.num_channels {
883 return Err(FerrotorchError::ShapeMismatch {
884 message: format!(
885 "GroupNorm: expected {} channels, got {}",
886 self.num_channels, channels
887 ),
888 });
889 }
890
891 let channels_per_group = channels / self.num_groups;
892 let spatial_size: usize = shape[2..].iter().product();
894 let spatial = spatial_size.max(1);
895 let group_size = channels_per_group * spatial;
896
897 if input.is_cuda() {
904 if let Some(backend) = gpu_backend() {
905 let eps_f32 = self.eps as f32;
906 let handle = backend.group_norm_f32(
907 input.gpu_handle()?,
908 self.weight.tensor().gpu_handle()?,
909 self.bias.tensor().gpu_handle()?,
910 batch_size,
911 channels,
912 self.num_groups,
913 spatial,
914 eps_f32,
915 )?;
916 return if is_grad_enabled() && input.requires_grad() {
917 let grad_fn = Arc::new(GroupNormBackward {
918 input: input.clone(),
919 weight: self.weight.tensor().clone(),
920 bias: self.bias.tensor().clone(),
921 num_groups: self.num_groups,
922 num_channels: self.num_channels,
923 eps: self.eps,
924 affine: self.affine,
925 });
926 Tensor::from_operation(TensorStorage::gpu(handle), shape.to_vec(), grad_fn)
927 } else {
928 Tensor::from_storage(TensorStorage::gpu(handle), shape.to_vec(), false)
929 };
930 }
931 return Err(FerrotorchError::NotImplementedOnCuda {
933 op: "GroupNorm::forward",
934 });
935 }
936 let input_data = input.data()?;
937 let weight_data = self.weight.tensor().data()?;
938 let bias_data = self.bias.tensor().data()?;
939 let eps_t = T::from(self.eps).unwrap();
940 let group_n = T::from(group_size).unwrap();
941
942 let mut output = vec![zero::<T>(); input.numel()];
943
944 for b in 0..batch_size {
945 for g in 0..self.num_groups {
946 let c_start = g * channels_per_group;
947 let c_end = c_start + channels_per_group;
948
949 let mut sum = zero::<T>();
951 for c in c_start..c_end {
952 for s in 0..spatial {
953 let idx = b * channels * spatial + c * spatial + s;
954 sum += input_data[idx];
955 }
956 }
957 let mean = sum / group_n;
958
959 let mut var_sum = zero::<T>();
961 for c in c_start..c_end {
962 for s in 0..spatial {
963 let idx = b * channels * spatial + c * spatial + s;
964 let d = input_data[idx] - mean;
965 var_sum += d * d;
966 }
967 }
968 let var = var_sum / group_n;
969 let inv_std = (var + eps_t).sqrt().recip();
970
971 for c in c_start..c_end {
973 for s in 0..spatial {
974 let idx = b * channels * spatial + c * spatial + s;
975 let normed = (input_data[idx] - mean) * inv_std;
976 if self.affine {
977 output[idx] = normed * weight_data[c] + bias_data[c];
978 } else {
979 output[idx] = normed;
980 }
981 }
982 }
983 }
984 }
985
986 let result = Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?;
987
988 if is_grad_enabled() && input.requires_grad() {
989 let grad_fn = Arc::new(GroupNormBackward {
990 input: input.clone(),
991 weight: self.weight.tensor().clone(),
992 bias: self.bias.tensor().clone(),
993 num_groups: self.num_groups,
994 num_channels: self.num_channels,
995 eps: self.eps,
996 affine: self.affine,
997 });
998 Tensor::from_operation(
999 TensorStorage::cpu(result.data()?.to_vec()),
1000 result.shape().to_vec(),
1001 grad_fn,
1002 )
1003 } else {
1004 Ok(result)
1005 }
1006 }
1007
1008 fn parameters(&self) -> Vec<&Parameter<T>> {
1009 if self.affine {
1010 vec![&self.weight, &self.bias]
1011 } else {
1012 vec![]
1013 }
1014 }
1015
1016 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1017 if self.affine {
1018 vec![&mut self.weight, &mut self.bias]
1019 } else {
1020 vec![]
1021 }
1022 }
1023
1024 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1025 if self.affine {
1026 vec![
1027 ("weight".to_string(), &self.weight),
1028 ("bias".to_string(), &self.bias),
1029 ]
1030 } else {
1031 vec![]
1032 }
1033 }
1034
1035 fn train(&mut self) {
1036 self.training = true;
1037 }
1038
1039 fn eval(&mut self) {
1040 self.training = false;
1041 }
1042
1043 fn is_training(&self) -> bool {
1044 self.training
1045 }
1046}
1047
/// Autograd node recorded by `GroupNorm::forward`.
///
/// It saves the forward input and the affine tensors. The backward pass
/// recomputes each group's mean and inverse std from `input`. Gradients are
/// returned in the order of `inputs()`: `[input, weight, bias]`.
#[derive(Debug)]
struct GroupNormBackward<T: Float> {
    /// The forward input (`[B, C, ...]`), needed to recompute group statistics.
    input: Tensor<T>,
    /// Per-channel scale used in the forward pass.
    weight: Tensor<T>,
    /// Per-channel shift. Only its `requires_grad` flag is consulted during backward.
    bias: Tensor<T>,
    /// Number of channel groups used in the forward pass.
    num_groups: usize,
    /// Channel count. It is also the length of the weight/bias gradients.
    num_channels: usize,
    /// Same stabilizer the forward pass added to the variance.
    eps: f64,
    /// Whether the affine parameters participate in the gradient.
    affine: bool,
}
1066
1067impl<T: Float> GradFn<T> for GroupNormBackward<T> {
1068 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1069 let shape = self.input.shape();
1070 let batch_size = shape[0];
1071 let channels = shape[1];
1072 let channels_per_group = channels / self.num_groups;
1073 let spatial_size: usize = shape[2..].iter().product();
1074 let spatial = spatial_size.max(1);
1075 let group_size = channels_per_group * spatial;
1076 let group_n = T::from(group_size).unwrap();
1077 let eps_t = T::from(self.eps).unwrap();
1078
1079 if self.input.is_cuda() {
1080 return Err(FerrotorchError::NotImplementedOnCuda {
1081 op: "GroupNormBackward",
1082 });
1083 }
1084 let input_data = self.input.data()?;
1085 let go_data = grad_output.data()?;
1086 let weight_data = self.weight.data()?;
1087
1088 let mut grad_input = vec![zero::<T>(); self.input.numel()];
1089 let mut grad_weight = vec![zero::<T>(); self.num_channels];
1090 let mut grad_bias = vec![zero::<T>(); self.num_channels];
1091
1092 for b in 0..batch_size {
1093 for g in 0..self.num_groups {
1094 let c_start = g * channels_per_group;
1095 let c_end = c_start + channels_per_group;
1096
1097 let mut sum = zero::<T>();
1099 for c in c_start..c_end {
1100 for s in 0..spatial {
1101 let idx = b * channels * spatial + c * spatial + s;
1102 sum += input_data[idx];
1103 }
1104 }
1105 let mean = sum / group_n;
1106
1107 let mut var_sum = zero::<T>();
1108 for c in c_start..c_end {
1109 for s in 0..spatial {
1110 let idx = b * channels * spatial + c * spatial + s;
1111 let d = input_data[idx] - mean;
1112 var_sum += d * d;
1113 }
1114 }
1115 let var = var_sum / group_n;
1116 let inv_std = (var + eps_t).sqrt().recip();
1117
1118 let mut dl_dx_hat_sum = zero::<T>();
1120 let mut dl_dx_hat_x_hat_sum = zero::<T>();
1121
1122 for c in c_start..c_end {
1123 for s in 0..spatial {
1124 let idx = b * channels * spatial + c * spatial + s;
1125 let x_hat = (input_data[idx] - mean) * inv_std;
1126 let dl_dx_hat = if self.affine {
1127 go_data[idx] * weight_data[c]
1128 } else {
1129 go_data[idx]
1130 };
1131 dl_dx_hat_sum += dl_dx_hat;
1132 dl_dx_hat_x_hat_sum += dl_dx_hat * x_hat;
1133
1134 if self.affine {
1135 grad_weight[c] += go_data[idx] * x_hat;
1136 grad_bias[c] += go_data[idx];
1137 }
1138 }
1139 }
1140
1141 let dl_dx_hat_mean = dl_dx_hat_sum / group_n;
1142 let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / group_n;
1143
1144 for (ci, &wd) in weight_data[c_start..c_end].iter().enumerate() {
1145 let c = c_start + ci;
1146 for s in 0..spatial {
1147 let idx = b * channels * spatial + c * spatial + s;
1148 let x_hat = (input_data[idx] - mean) * inv_std;
1149 let dl_dx_hat = if self.affine {
1150 go_data[idx] * wd
1151 } else {
1152 go_data[idx]
1153 };
1154 grad_input[idx] =
1155 inv_std * (dl_dx_hat - dl_dx_hat_mean - x_hat * dl_dx_hat_x_hat_mean);
1156 }
1157 }
1158 }
1159 }
1160
1161 let grad_input_tensor = Tensor::from_storage(
1162 TensorStorage::cpu(grad_input),
1163 self.input.shape().to_vec(),
1164 false,
1165 )?;
1166
1167 let grad_weight_out = if self.affine && self.weight.requires_grad() {
1168 Some(Tensor::from_storage(
1169 TensorStorage::cpu(grad_weight),
1170 vec![self.num_channels],
1171 false,
1172 )?)
1173 } else {
1174 None
1175 };
1176
1177 let grad_bias_out = if self.affine && self.bias.requires_grad() {
1178 Some(Tensor::from_storage(
1179 TensorStorage::cpu(grad_bias),
1180 vec![self.num_channels],
1181 false,
1182 )?)
1183 } else {
1184 None
1185 };
1186
1187 Ok(vec![
1188 Some(grad_input_tensor),
1189 grad_weight_out,
1190 grad_bias_out,
1191 ])
1192 }
1193
1194 fn inputs(&self) -> Vec<&Tensor<T>> {
1195 vec![&self.input, &self.weight, &self.bias]
1196 }
1197
1198 fn name(&self) -> &'static str {
1199 "GroupNormBackward"
1200 }
1201}
1202
/// Root-mean-square normalization over the trailing `normalized_shape` dims.
///
/// Each slice is divided by `sqrt(mean(x^2) + eps)` and scaled by `weight`.
/// There is no mean subtraction and no bias. Build it with [`RMSNorm::new`],
/// which initializes `weight` to ones.
#[derive(Debug)]
pub struct RMSNorm<T: Float> {
    /// Trailing input dims to normalize over. `forward` requires an exact match.
    pub normalized_shape: Vec<usize>,
    /// Added to the mean square before taking the square root.
    pub eps: f64,
    /// Per-element scale, shaped like `normalized_shape`.
    pub weight: Parameter<T>,
    /// Train/eval flag set by `train()` / `eval()`. RMSNorm's `forward` does not read it.
    training: bool,
}
1231
1232impl<T: Float> RMSNorm<T> {
1233 pub fn new(normalized_shape: Vec<usize>, eps: f64) -> FerrotorchResult<Self> {
1240 if normalized_shape.is_empty() {
1241 return Err(FerrotorchError::InvalidArgument {
1242 message: "normalized_shape must not be empty".into(),
1243 });
1244 }
1245
1246 let weight = Parameter::ones(&normalized_shape)?;
1247
1248 Ok(Self {
1249 normalized_shape,
1250 eps,
1251 weight,
1252 training: true,
1253 })
1254 }
1255}
1256
1257impl<T: Float> Module<T> for RMSNorm<T> {
1258 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1259 let shape = input.shape().to_vec();
1260 let ndim = shape.len();
1261 let norm_ndim = self.normalized_shape.len();
1262
1263 if ndim < norm_ndim {
1264 return Err(FerrotorchError::ShapeMismatch {
1265 message: format!(
1266 "RMSNorm: input has {} dims but normalized_shape has {} dims",
1267 ndim, norm_ndim
1268 ),
1269 });
1270 }
1271
1272 let last_dims = &shape[ndim - norm_ndim..];
1273 if last_dims != self.normalized_shape.as_slice() {
1274 return Err(FerrotorchError::ShapeMismatch {
1275 message: format!(
1276 "RMSNorm: input last dims {:?} don't match normalized_shape {:?}",
1277 last_dims, self.normalized_shape
1278 ),
1279 });
1280 }
1281
1282 let norm_size: usize = self.normalized_shape.iter().product();
1283 let batch_size = input.numel() / norm_size;
1284
1285 if input.is_cuda() {
1287 if let Some(backend) = ferrotorch_core::gpu_dispatch::gpu_backend() {
1288 let eps_f32 = self.eps as f32;
1289 let handle = backend.rmsnorm_f32(
1290 input.gpu_handle()?,
1291 self.weight.tensor().gpu_handle()?,
1292 batch_size,
1293 norm_size,
1294 eps_f32,
1295 )?;
1296 return if is_grad_enabled() && input.requires_grad() {
1297 let grad_fn = Arc::new(RMSNormBackward {
1298 input: input.clone(),
1299 weight: self.weight.tensor().clone(),
1300 normalized_shape: self.normalized_shape.clone(),
1301 eps: self.eps,
1302 });
1303 Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
1304 } else {
1305 Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
1306 };
1307 }
1308 }
1309
1310 if input.is_cuda() {
1312 return Err(FerrotorchError::NotImplementedOnCuda {
1313 op: "RMSNorm::forward",
1314 });
1315 }
1316 let input_data = input.data()?;
1317 let weight_data = self.weight.tensor().data()?;
1318 let eps_t = T::from(self.eps).unwrap();
1319 let n_t = T::from(norm_size).unwrap();
1320
1321 let is_bf16 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<half::bf16>();
1326 let mut output = Vec::with_capacity(input.numel());
1327
1328 for b in 0..batch_size {
1329 let start = b * norm_size;
1330 let end = start + norm_size;
1331 let slice = &input_data[start..end];
1332
1333 if is_bf16 {
1334 let eps_f32 = self.eps as f32;
1336 let n_f32 = norm_size as f32;
1337 let mut sum_sq = 0.0f32;
1338 for &x in slice {
1339 let xf = x.to_f32().unwrap();
1340 sum_sq += xf * xf;
1341 }
1342 let inv_rms_f32 = 1.0f32 / ((sum_sq / n_f32) + eps_f32).sqrt();
1343 let inv_rms = T::from(inv_rms_f32).unwrap();
1344 for (i, &x) in slice.iter().enumerate() {
1345 output.push(x * inv_rms * weight_data[i]);
1346 }
1347 } else {
1348 let mean_sq = slice.iter().copied().fold(zero::<T>(), |a, x| a + x * x) / n_t;
1350 let rms = (mean_sq + eps_t).sqrt();
1351 let inv_rms = rms.recip();
1352
1353 for (i, &x) in slice.iter().enumerate() {
1354 output.push(x * inv_rms * weight_data[i]);
1355 }
1356 }
1357 }
1358
1359 let result = Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?;
1360
1361 if is_grad_enabled() && input.requires_grad() {
1362 let grad_fn = Arc::new(RMSNormBackward {
1363 input: input.clone(),
1364 weight: self.weight.tensor().clone(),
1365 normalized_shape: self.normalized_shape.clone(),
1366 eps: self.eps,
1367 });
1368 Tensor::from_operation(
1369 TensorStorage::cpu(result.data()?.to_vec()),
1370 result.shape().to_vec(),
1371 grad_fn,
1372 )
1373 } else {
1374 Ok(result)
1375 }
1376 }
1377
1378 fn parameters(&self) -> Vec<&Parameter<T>> {
1379 vec![&self.weight]
1380 }
1381
1382 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1383 vec![&mut self.weight]
1384 }
1385
1386 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1387 vec![("weight".to_string(), &self.weight)]
1388 }
1389
1390 fn train(&mut self) {
1391 self.training = true;
1392 }
1393
1394 fn eval(&mut self) {
1395 self.training = false;
1396 }
1397
1398 fn is_training(&self) -> bool {
1399 self.training
1400 }
1401}
1402
/// Autograd node recorded by `RMSNorm::forward`.
///
/// It saves the forward input and scale. The backward pass recomputes each
/// slice's inverse RMS from `input`.
#[derive(Debug)]
struct RMSNormBackward<T: Float> {
    /// The forward input, needed to recompute the per-slice RMS.
    input: Tensor<T>,
    /// Per-element scale used in the forward pass.
    weight: Tensor<T>,
    /// Trailing dims that were normalized. Their product is the slice length.
    normalized_shape: Vec<usize>,
    /// Same stabilizer the forward pass added to the mean square.
    eps: f64,
}
1431
1432impl<T: Float> GradFn<T> for RMSNormBackward<T> {
1433 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1434 let norm_size: usize = self.normalized_shape.iter().product();
1435 let batch_size = self.input.numel() / norm_size;
1436
1437 if self.input.is_cuda() && (is_f32::<T>() || is_f64::<T>()) {
1439 if let Some(backend) = gpu_backend() {
1440 let (gi_h, gw_h) = if is_f64::<T>() {
1441 backend.rmsnorm_backward_f64(
1442 self.input.gpu_handle()?,
1443 grad_output.gpu_handle()?,
1444 self.weight.gpu_handle()?,
1445 batch_size,
1446 norm_size,
1447 self.eps,
1448 )?
1449 } else {
1450 backend.rmsnorm_backward_f32(
1451 self.input.gpu_handle()?,
1452 grad_output.gpu_handle()?,
1453 self.weight.gpu_handle()?,
1454 batch_size,
1455 norm_size,
1456 self.eps as f32,
1457 )?
1458 };
1459
1460 let grad_input_tensor = Tensor::from_storage(
1461 TensorStorage::gpu(gi_h),
1462 self.input.shape().to_vec(),
1463 false,
1464 )?;
1465
1466 let grad_weight_out = if self.weight.requires_grad() {
1467 Some(Tensor::from_storage(
1468 TensorStorage::gpu(gw_h),
1469 self.normalized_shape.clone(),
1470 false,
1471 )?)
1472 } else {
1473 None
1474 };
1475
1476 return Ok(vec![Some(grad_input_tensor), grad_weight_out]);
1477 }
1478 }
1479
1480 if self.input.is_cuda() {
1482 return Err(FerrotorchError::NotImplementedOnCuda {
1483 op: "RMSNormBackward",
1484 });
1485 }
1486 let n_t = T::from(norm_size).unwrap();
1487 let eps_t = T::from(self.eps).unwrap();
1488
1489 let input_data = self.input.data()?;
1490 let go_data = grad_output.data()?;
1491 let weight_data = self.weight.data()?;
1492
1493 let mut grad_input = vec![zero::<T>(); self.input.numel()];
1494 let mut grad_weight = vec![zero::<T>(); norm_size];
1495
1496 for b in 0..batch_size {
1497 let start = b * norm_size;
1498 let end = start + norm_size;
1499 let x_slice = &input_data[start..end];
1500 let go_slice = &go_data[start..end];
1501
1502 let mean_sq = x_slice.iter().copied().fold(zero::<T>(), |a, x| a + x * x) / n_t;
1504 let rms = (mean_sq + eps_t).sqrt();
1505 let inv_rms = rms.recip();
1506 let inv_rms_sq = inv_rms * inv_rms;
1507
1508 let go_x_w_mean = x_slice
1510 .iter()
1511 .zip(go_slice.iter())
1512 .zip(weight_data.iter())
1513 .fold(zero::<T>(), |a, ((&x, &go), &w)| a + go * x * w)
1514 / n_t;
1515
1516 for i in 0..norm_size {
1517 grad_input[start + i] = inv_rms
1519 * (go_slice[i] * weight_data[i] - x_slice[i] * inv_rms_sq * go_x_w_mean);
1520
1521 grad_weight[i] += go_slice[i] * x_slice[i] * inv_rms;
1523 }
1524 }
1525
1526 let grad_input_tensor = Tensor::from_storage(
1527 TensorStorage::cpu(grad_input),
1528 self.input.shape().to_vec(),
1529 false,
1530 )?;
1531
1532 let grad_weight_out = if self.weight.requires_grad() {
1533 Some(Tensor::from_storage(
1534 TensorStorage::cpu(grad_weight),
1535 self.normalized_shape.clone(),
1536 false,
1537 )?)
1538 } else {
1539 None
1540 };
1541
1542 Ok(vec![Some(grad_input_tensor), grad_weight_out])
1543 }
1544
1545 fn inputs(&self) -> Vec<&Tensor<T>> {
1546 vec![&self.input, &self.weight]
1547 }
1548
1549 fn name(&self) -> &'static str {
1550 "RMSNormBackward"
1551 }
1552}
1553
1554pub struct BatchNorm2d<T: Float> {
1577 pub num_features: usize,
1579 pub eps: f64,
1581 pub momentum: f64,
1584 pub affine: bool,
1586 pub weight: Option<Parameter<T>>,
1588 pub bias: Option<Parameter<T>>,
1590 running_mean: Mutex<Vec<f64>>,
1594 running_var: Mutex<Vec<f64>>,
1596 num_batches_tracked: Mutex<usize>,
1598 training: Mutex<bool>,
1600}
1601
1602impl<T: Float> std::fmt::Debug for BatchNorm2d<T> {
1604 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1605 f.debug_struct("BatchNorm2d")
1606 .field("num_features", &self.num_features)
1607 .field("eps", &self.eps)
1608 .field("momentum", &self.momentum)
1609 .field("affine", &self.affine)
1610 .field("weight", &self.weight)
1611 .field("bias", &self.bias)
1612 .field("training", &self.training)
1613 .finish()
1614 }
1615}
1616
1617impl<T: Float> BatchNorm2d<T> {
1618 pub fn new(
1627 num_features: usize,
1628 eps: f64,
1629 momentum: f64,
1630 affine: bool,
1631 ) -> FerrotorchResult<Self> {
1632 if num_features == 0 {
1633 return Err(FerrotorchError::InvalidArgument {
1634 message: "num_features must be positive".into(),
1635 });
1636 }
1637
1638 let weight = if affine {
1639 Some(Parameter::ones(&[num_features])?)
1640 } else {
1641 None
1642 };
1643
1644 let bias = if affine {
1645 Some(Parameter::zeros(&[num_features])?)
1646 } else {
1647 None
1648 };
1649
1650 Ok(Self {
1651 num_features,
1652 eps,
1653 momentum,
1654 affine,
1655 weight,
1656 bias,
1657 running_mean: Mutex::new(vec![0.0; num_features]),
1658 running_var: Mutex::new(vec![1.0; num_features]),
1659 num_batches_tracked: Mutex::new(0),
1660 training: Mutex::new(true),
1661 })
1662 }
1663
1664 pub fn running_mean(&self) -> Vec<f64> {
1666 self.running_mean.lock().unwrap().clone()
1667 }
1668
1669 pub fn running_var(&self) -> Vec<f64> {
1671 self.running_var.lock().unwrap().clone()
1672 }
1673
1674 pub fn num_batches_tracked(&self) -> usize {
1676 *self.num_batches_tracked.lock().unwrap()
1677 }
1678
1679 pub fn set_running_mean(&self, value: &[T]) -> FerrotorchResult<()> {
1698 if value.len() != self.num_features {
1699 return Err(FerrotorchError::ShapeMismatch {
1700 message: format!(
1701 "BatchNorm2d::set_running_mean: expected slice of length \
1702 num_features={}, got {}",
1703 self.num_features,
1704 value.len()
1705 ),
1706 });
1707 }
1708 for (i, x) in value.iter().enumerate() {
1709 if !num_traits::Float::is_finite(*x) {
1710 return Err(FerrotorchError::InvalidArgument {
1711 message: format!(
1712 "BatchNorm2d::set_running_mean: non-finite value at \
1713 index {i} (running_mean must be finite)"
1714 ),
1715 });
1716 }
1717 }
1718 let mut rm = self.running_mean.lock().unwrap();
1719 for (slot, x) in rm.iter_mut().zip(value.iter()) {
1720 *slot = x.to_f64().unwrap();
1721 }
1722 Ok(())
1723 }
1724
1725 pub fn set_running_var(&self, value: &[T]) -> FerrotorchResult<()> {
1741 if value.len() != self.num_features {
1742 return Err(FerrotorchError::ShapeMismatch {
1743 message: format!(
1744 "BatchNorm2d::set_running_var: expected slice of length \
1745 num_features={}, got {}",
1746 self.num_features,
1747 value.len()
1748 ),
1749 });
1750 }
1751 let zero_t = zero::<T>();
1752 for (i, x) in value.iter().enumerate() {
1753 if !num_traits::Float::is_finite(*x) {
1754 return Err(FerrotorchError::InvalidArgument {
1755 message: format!(
1756 "BatchNorm2d::set_running_var: non-finite value at \
1757 index {i} (running_var must be finite)"
1758 ),
1759 });
1760 }
1761 if *x < zero_t {
1762 return Err(FerrotorchError::InvalidArgument {
1763 message: format!(
1764 "BatchNorm2d::set_running_var: negative value {} at \
1765 index {i} (running_var must be non-negative)",
1766 x.to_f64().unwrap()
1767 ),
1768 });
1769 }
1770 }
1771 let mut rv = self.running_var.lock().unwrap();
1772 for (slot, x) in rv.iter_mut().zip(value.iter()) {
1773 *slot = x.to_f64().unwrap();
1774 }
1775 Ok(())
1776 }
1777
1778 pub fn set_num_batches_tracked(&self, value: usize) -> FerrotorchResult<()> {
1785 let mut nbt = self.num_batches_tracked.lock().unwrap();
1786 *nbt = value;
1787 Ok(())
1788 }
1789}
1790
1791impl<T: Float> Module<T> for BatchNorm2d<T> {
1792 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1793 let shape = input.shape().to_vec();
1794 if shape.len() != 4 {
1795 return Err(FerrotorchError::ShapeMismatch {
1796 message: format!(
1797 "BatchNorm2d: expected 4D input [B, C, H, W], got {:?}",
1798 shape
1799 ),
1800 });
1801 }
1802
1803 let batch = shape[0];
1804 let channels = shape[1];
1805 let height = shape[2];
1806 let width = shape[3];
1807 let spatial = height * width;
1808
1809 if channels != self.num_features {
1810 return Err(FerrotorchError::ShapeMismatch {
1811 message: format!(
1812 "BatchNorm2d: expected {} channels, got {}",
1813 self.num_features, channels
1814 ),
1815 });
1816 }
1817
1818 if *self.training.lock().unwrap() && batch * spatial <= 1 {
1824 return Err(FerrotorchError::InvalidArgument {
1825 message: format!(
1826 "Expected more than 1 value per channel when training, got input size {:?}",
1827 shape
1828 ),
1829 });
1830 }
1831
1832 if input.is_cuda() {
1835 if is_f32::<T>() {
1836 let is_training = *self.training.lock().unwrap();
1837 if let Some(handle) = batch_norm_gpu_forward(
1838 input,
1839 self.weight.as_ref().map(|w| w.tensor()),
1840 self.bias.as_ref().map(|b| b.tensor()),
1841 &self.running_mean,
1842 &self.running_var,
1843 &self.num_batches_tracked,
1844 self.momentum,
1845 self.eps,
1846 channels,
1847 spatial,
1848 is_training,
1849 )? {
1850 return if is_grad_enabled() && input.requires_grad() {
1851 let grad_fn = Arc::new(BatchNorm2dBackward {
1852 input: input.clone(),
1853 x_hat: Tensor::from_storage(
1854 TensorStorage::cpu(Vec::new()),
1855 vec![0],
1856 false,
1857 )?,
1858 weight: self.weight.as_ref().map(|w| w.tensor().clone()),
1859 bias: self.bias.as_ref().map(|b| b.tensor().clone()),
1860 chan_var: Vec::new(),
1861 eps: self.eps,
1862 affine: self.affine,
1863 is_training,
1864 running_mean: self.running_mean.lock().unwrap().clone(),
1865 running_var: self.running_var.lock().unwrap().clone(),
1866 });
1867 Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
1868 } else {
1869 Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
1870 };
1871 }
1872 }
1873 return Err(FerrotorchError::NotImplementedOnCuda {
1874 op: "BatchNorm2d::forward",
1875 });
1876 }
1877 let input_data = input.data()?;
1878 let eps_t = T::from(self.eps).unwrap();
1879
1880 let weight_data = self.weight.as_ref().map(|w| w.tensor().data().unwrap());
1881 let bias_data = self.bias.as_ref().map(|b| b.tensor().data().unwrap());
1882
1883 let is_training = *self.training.lock().unwrap();
1884
1885 let mut chan_mean = vec![zero::<T>(); channels];
1887 let mut chan_var = vec![zero::<T>(); channels];
1888
1889 if is_training {
1890 let count = batch * spatial;
1892 let count_t = T::from(count).unwrap();
1893
1894 for c in 0..channels {
1895 let mut sum = zero::<T>();
1896 for b in 0..batch {
1897 let base = b * channels * spatial + c * spatial;
1898 for s in 0..spatial {
1899 sum += input_data[base + s];
1900 }
1901 }
1902 chan_mean[c] = sum / count_t;
1903
1904 let mut var_sum = zero::<T>();
1905 for b in 0..batch {
1906 let base = b * channels * spatial + c * spatial;
1907 for s in 0..spatial {
1908 let d = input_data[base + s] - chan_mean[c];
1909 var_sum += d * d;
1910 }
1911 }
1912 chan_var[c] = var_sum / count_t;
1914 }
1915
1916 {
1918 let mut rm = self.running_mean.lock().unwrap();
1919 let mut rv = self.running_var.lock().unwrap();
1920 let mut nbt = self.num_batches_tracked.lock().unwrap();
1921 *nbt += 1;
1922
1923 let mom = self.momentum;
1924 let bessel = if count > 1 {
1927 count as f64 / (count as f64 - 1.0)
1928 } else {
1929 1.0
1930 };
1931
1932 for c in 0..channels {
1933 let batch_mean_f64 = chan_mean[c].to_f64().unwrap();
1934 let batch_var_f64 = chan_var[c].to_f64().unwrap();
1935
1936 rm[c] = (1.0 - mom) * rm[c] + mom * batch_mean_f64;
1937 rv[c] = (1.0 - mom) * rv[c] + mom * batch_var_f64 * bessel;
1938 }
1939 }
1940 } else {
1941 let rm = self.running_mean.lock().unwrap();
1943 let rv = self.running_var.lock().unwrap();
1944
1945 for c in 0..channels {
1946 chan_mean[c] = T::from(rm[c]).unwrap();
1947 chan_var[c] = T::from(rv[c]).unwrap();
1948 }
1949 }
1950
1951 let mut output = vec![zero::<T>(); input.numel()];
1953
1954 let mut inv_std = vec![zero::<T>(); channels];
1956 let mut x_hat_data = if is_grad_enabled() && input.requires_grad() {
1958 Vec::with_capacity(input.numel())
1959 } else {
1960 Vec::new()
1961 };
1962 let need_x_hat = is_grad_enabled() && input.requires_grad();
1963
1964 for c in 0..channels {
1965 inv_std[c] = (chan_var[c] + eps_t).sqrt().recip();
1966 }
1967
1968 for b in 0..batch {
1969 for c in 0..channels {
1970 let base = b * channels * spatial + c * spatial;
1971 for s in 0..spatial {
1972 let idx = base + s;
1973 let normed = (input_data[idx] - chan_mean[c]) * inv_std[c];
1974
1975 if need_x_hat {
1976 x_hat_data.push(normed);
1977 }
1978
1979 if self.affine {
1980 let w = weight_data.as_ref().unwrap();
1981 let bi = bias_data.as_ref().unwrap();
1982 output[idx] = normed * w[c] + bi[c];
1983 } else {
1984 output[idx] = normed;
1985 }
1986 }
1987 }
1988 }
1989
1990 let result = Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?;
1991
1992 if is_grad_enabled() && input.requires_grad() {
1993 let weight_tensor = self.weight.as_ref().map(|w| w.tensor().clone());
1994 let bias_tensor = self.bias.as_ref().map(|b| b.tensor().clone());
1995
1996 let grad_fn = Arc::new(BatchNorm2dBackward {
1997 input: input.clone(),
1998 x_hat: Tensor::from_storage(TensorStorage::cpu(x_hat_data), shape.to_vec(), false)?,
1999 weight: weight_tensor,
2000 bias: bias_tensor,
2001 chan_var: chan_var.iter().map(|v| v.to_f64().unwrap()).collect(),
2002 eps: self.eps,
2003 affine: self.affine,
2004 is_training,
2005 running_mean: self.running_mean.lock().unwrap().clone(),
2006 running_var: self.running_var.lock().unwrap().clone(),
2007 });
2008
2009 Tensor::from_operation(
2010 TensorStorage::cpu(result.data()?.to_vec()),
2011 result.shape().to_vec(),
2012 grad_fn,
2013 )
2014 } else {
2015 Ok(result)
2016 }
2017 }
2018
2019 fn parameters(&self) -> Vec<&Parameter<T>> {
2020 match (&self.weight, &self.bias) {
2021 (Some(w), Some(b)) => vec![w, b],
2022 _ => vec![],
2023 }
2024 }
2025
2026 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
2027 match (&mut self.weight, &mut self.bias) {
2028 (Some(w), Some(b)) => vec![w, b],
2029 _ => vec![],
2030 }
2031 }
2032
2033 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
2034 match (&self.weight, &self.bias) {
2035 (Some(w), Some(b)) => vec![("weight".to_string(), w), ("bias".to_string(), b)],
2036 _ => vec![],
2037 }
2038 }
2039
2040 fn train(&mut self) {
2041 *self.training.lock().unwrap() = true;
2042 }
2043
2044 fn eval(&mut self) {
2045 *self.training.lock().unwrap() = false;
2046 }
2047
2048 fn is_training(&self) -> bool {
2049 *self.training.lock().unwrap()
2050 }
2051
2052 fn as_any(&self) -> Option<&dyn std::any::Any> {
2058 Some(self)
2059 }
2060}
2061
2062#[derive(Debug)]
2090struct BatchNorm2dBackward<T: Float> {
2091 input: Tensor<T>,
2092 x_hat: Tensor<T>,
2095 weight: Option<Tensor<T>>,
2096 bias: Option<Tensor<T>>,
2097 chan_var: Vec<f64>,
2099 eps: f64,
2100 affine: bool,
2101 is_training: bool,
2104 running_mean: Vec<f64>,
2106 running_var: Vec<f64>,
2108}
2109
2110impl<T: Float> GradFn<T> for BatchNorm2dBackward<T> {
2111 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2112 let shape = self.input.shape();
2113 let batch = shape[0];
2114 let channels = shape[1];
2115 let height = shape[2];
2116 let width = shape[3];
2117 let spatial = height * width;
2118 let count = batch * spatial;
2119 let count_t = T::from(count).unwrap();
2120
2121 if self.input.is_cuda() && is_f32::<T>() {
2125 let weight_dev;
2126 let weight_buf = match self.weight.as_ref() {
2127 Some(w) => w,
2128 None => {
2129 weight_dev = ferrotorch_core::creation::ones::<T>(&[channels])?
2130 .to(self.input.device())?;
2131 &weight_dev
2132 }
2133 };
2134 if let Some(grads) = batch_norm_gpu_backward(
2135 &self.input,
2136 grad_output,
2137 weight_buf,
2138 &self.running_mean,
2139 &self.running_var,
2140 batch,
2141 channels,
2142 spatial,
2143 self.eps,
2144 self.is_training,
2145 self.affine,
2146 self.weight.as_ref().is_some_and(|w| w.requires_grad()),
2147 self.bias.as_ref().is_some_and(|b| b.requires_grad()),
2148 )? {
2149 return Ok(grads);
2150 }
2151 return Err(FerrotorchError::NotImplementedOnCuda {
2152 op: "BatchNorm2dBackward",
2153 });
2154 }
2155 if self.input.is_cuda() {
2156 return Err(FerrotorchError::NotImplementedOnCuda {
2157 op: "BatchNorm2dBackward",
2158 });
2159 }
2160 let go_data = grad_output.data()?;
2161 let x_hat_data = self.x_hat.data()?;
2162
2163 let weight_data = self.weight.as_ref().map(|w| w.data().unwrap().to_vec());
2164
2165 let mut grad_input = vec![zero::<T>(); self.input.numel()];
2166 let mut grad_weight = vec![zero::<T>(); channels];
2167 let mut grad_bias = vec![zero::<T>(); channels];
2168
2169 for c in 0..channels {
2170 let var_f64 = self.chan_var[c];
2171 let inv_std = T::from(1.0 / (var_f64 + self.eps).sqrt()).unwrap();
2172
2173 let mut dl_dx_hat_sum = zero::<T>();
2175 let mut dl_dx_hat_x_hat_sum = zero::<T>();
2176
2177 for b in 0..batch {
2178 let base = b * channels * spatial + c * spatial;
2179 for s in 0..spatial {
2180 let idx = base + s;
2181 let x_h = x_hat_data[idx];
2182 let go = go_data[idx];
2183
2184 let dl_dx_hat = if self.affine {
2185 go * weight_data.as_ref().unwrap()[c]
2186 } else {
2187 go
2188 };
2189
2190 dl_dx_hat_sum += dl_dx_hat;
2191 dl_dx_hat_x_hat_sum += dl_dx_hat * x_h;
2192
2193 if self.affine {
2194 grad_weight[c] += go * x_h;
2195 grad_bias[c] += go;
2196 }
2197 }
2198 }
2199
2200 let dl_dx_hat_mean = dl_dx_hat_sum / count_t;
2201 let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / count_t;
2202
2203 for b in 0..batch {
2205 let base = b * channels * spatial + c * spatial;
2206 for s in 0..spatial {
2207 let idx = base + s;
2208 let x_h = x_hat_data[idx];
2209 let go = go_data[idx];
2210
2211 let dl_dx_hat = if self.affine {
2212 go * weight_data.as_ref().unwrap()[c]
2213 } else {
2214 go
2215 };
2216
2217 grad_input[idx] =
2218 inv_std * (dl_dx_hat - dl_dx_hat_mean - x_h * dl_dx_hat_x_hat_mean);
2219 }
2220 }
2221 }
2222
2223 let grad_input_tensor = Tensor::from_storage(
2224 TensorStorage::cpu(grad_input),
2225 self.input.shape().to_vec(),
2226 false,
2227 )?;
2228
2229 let grad_weight_out = if self.affine {
2230 if let Some(ref w) = self.weight {
2231 if w.requires_grad() {
2232 Some(Tensor::from_storage(
2233 TensorStorage::cpu(grad_weight),
2234 vec![channels],
2235 false,
2236 )?)
2237 } else {
2238 None
2239 }
2240 } else {
2241 None
2242 }
2243 } else {
2244 None
2245 };
2246
2247 let grad_bias_out = if self.affine {
2248 if let Some(ref b) = self.bias {
2249 if b.requires_grad() {
2250 Some(Tensor::from_storage(
2251 TensorStorage::cpu(grad_bias),
2252 vec![channels],
2253 false,
2254 )?)
2255 } else {
2256 None
2257 }
2258 } else {
2259 None
2260 }
2261 } else {
2262 None
2263 };
2264
2265 if self.affine {
2271 Ok(vec![
2272 Some(grad_input_tensor),
2273 grad_weight_out,
2274 grad_bias_out,
2275 ])
2276 } else {
2277 Ok(vec![Some(grad_input_tensor)])
2278 }
2279 }
2280
2281 fn inputs(&self) -> Vec<&Tensor<T>> {
2282 let mut v: Vec<&Tensor<T>> = vec![&self.input];
2283 if let Some(ref w) = self.weight {
2284 v.push(w);
2285 }
2286 if let Some(ref b) = self.bias {
2287 v.push(b);
2288 }
2289 v
2290 }
2291
2292 fn name(&self) -> &'static str {
2293 "BatchNorm2dBackward"
2294 }
2295}
2296
2297pub struct BatchNorm1d<T: Float> {
2318 pub num_features: usize,
2320 pub eps: f64,
2322 pub momentum: f64,
2325 pub affine: bool,
2327 pub weight: Option<Parameter<T>>,
2329 pub bias: Option<Parameter<T>>,
2331 running_mean: Mutex<Vec<f64>>,
2333 running_var: Mutex<Vec<f64>>,
2335 num_batches_tracked: Mutex<usize>,
2337 training: Mutex<bool>,
2339}
2340
2341impl<T: Float> std::fmt::Debug for BatchNorm1d<T> {
2342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2343 f.debug_struct("BatchNorm1d")
2344 .field("num_features", &self.num_features)
2345 .field("eps", &self.eps)
2346 .field("momentum", &self.momentum)
2347 .field("affine", &self.affine)
2348 .field("weight", &self.weight)
2349 .field("bias", &self.bias)
2350 .field("training", &self.training)
2351 .finish()
2352 }
2353}
2354
2355impl<T: Float> BatchNorm1d<T> {
2356 pub fn new(
2365 num_features: usize,
2366 eps: f64,
2367 momentum: f64,
2368 affine: bool,
2369 ) -> FerrotorchResult<Self> {
2370 if num_features == 0 {
2371 return Err(FerrotorchError::InvalidArgument {
2372 message: "BatchNorm1d: num_features must be positive".into(),
2373 });
2374 }
2375
2376 let weight = if affine {
2377 Some(Parameter::ones(&[num_features])?)
2378 } else {
2379 None
2380 };
2381
2382 let bias = if affine {
2383 Some(Parameter::zeros(&[num_features])?)
2384 } else {
2385 None
2386 };
2387
2388 Ok(Self {
2389 num_features,
2390 eps,
2391 momentum,
2392 affine,
2393 weight,
2394 bias,
2395 running_mean: Mutex::new(vec![0.0; num_features]),
2396 running_var: Mutex::new(vec![1.0; num_features]),
2397 num_batches_tracked: Mutex::new(0),
2398 training: Mutex::new(true),
2399 })
2400 }
2401
2402 pub fn running_mean(&self) -> Vec<f64> {
2404 self.running_mean.lock().unwrap().clone()
2405 }
2406
2407 pub fn running_var(&self) -> Vec<f64> {
2409 self.running_var.lock().unwrap().clone()
2410 }
2411
2412 pub fn num_batches_tracked(&self) -> usize {
2414 *self.num_batches_tracked.lock().unwrap()
2415 }
2416
2417 pub fn set_running_mean(&self, value: &[T]) -> FerrotorchResult<()> {
2425 if value.len() != self.num_features {
2426 return Err(FerrotorchError::ShapeMismatch {
2427 message: format!(
2428 "BatchNorm1d::set_running_mean: expected slice of length \
2429 num_features={}, got {}",
2430 self.num_features,
2431 value.len()
2432 ),
2433 });
2434 }
2435 for (i, x) in value.iter().enumerate() {
2436 if !num_traits::Float::is_finite(*x) {
2437 return Err(FerrotorchError::InvalidArgument {
2438 message: format!(
2439 "BatchNorm1d::set_running_mean: non-finite value at \
2440 index {i} (running_mean must be finite)"
2441 ),
2442 });
2443 }
2444 }
2445 let mut rm = self.running_mean.lock().unwrap();
2446 for (slot, x) in rm.iter_mut().zip(value.iter()) {
2447 *slot = x.to_f64().unwrap();
2448 }
2449 Ok(())
2450 }
2451
2452 pub fn set_running_var(&self, value: &[T]) -> FerrotorchResult<()> {
2460 if value.len() != self.num_features {
2461 return Err(FerrotorchError::ShapeMismatch {
2462 message: format!(
2463 "BatchNorm1d::set_running_var: expected slice of length \
2464 num_features={}, got {}",
2465 self.num_features,
2466 value.len()
2467 ),
2468 });
2469 }
2470 let zero_t = zero::<T>();
2471 for (i, x) in value.iter().enumerate() {
2472 if !num_traits::Float::is_finite(*x) {
2473 return Err(FerrotorchError::InvalidArgument {
2474 message: format!(
2475 "BatchNorm1d::set_running_var: non-finite value at \
2476 index {i} (running_var must be finite)"
2477 ),
2478 });
2479 }
2480 if *x < zero_t {
2481 return Err(FerrotorchError::InvalidArgument {
2482 message: format!(
2483 "BatchNorm1d::set_running_var: negative value {} at \
2484 index {i} (running_var must be non-negative)",
2485 x.to_f64().unwrap()
2486 ),
2487 });
2488 }
2489 }
2490 let mut rv = self.running_var.lock().unwrap();
2491 for (slot, x) in rv.iter_mut().zip(value.iter()) {
2492 *slot = x.to_f64().unwrap();
2493 }
2494 Ok(())
2495 }
2496
2497 pub fn set_num_batches_tracked(&self, value: usize) -> FerrotorchResult<()> {
2501 let mut nbt = self.num_batches_tracked.lock().unwrap();
2502 *nbt = value;
2503 Ok(())
2504 }
2505}
2506
2507impl<T: Float> Module<T> for BatchNorm1d<T> {
2508 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2509 let shape = input.shape().to_vec();
2510 let ndim = shape.len();
2511
2512 if ndim != 2 && ndim != 3 {
2514 return Err(FerrotorchError::ShapeMismatch {
2515 message: format!(
2516 "BatchNorm1d: expected 2D [N, C] or 3D [N, C, L] input, got {:?}",
2517 shape
2518 ),
2519 });
2520 }
2521
2522 let batch = shape[0];
2523 let channels = shape[1];
2524 let length = if ndim == 3 { shape[2] } else { 1 };
2525
2526 if channels != self.num_features {
2527 return Err(FerrotorchError::ShapeMismatch {
2528 message: format!(
2529 "BatchNorm1d: expected {} channels, got {}",
2530 self.num_features, channels
2531 ),
2532 });
2533 }
2534
2535 if batch == 0 {
2537 return Ok(input.clone());
2538 }
2539
2540 if *self.training.lock().unwrap() && batch * length <= 1 {
2546 return Err(FerrotorchError::InvalidArgument {
2547 message: format!(
2548 "Expected more than 1 value per channel when training, got input size {:?}",
2549 shape
2550 ),
2551 });
2552 }
2553
2554 if input.is_cuda() {
2556 if is_f32::<T>() {
2557 let is_training = *self.training.lock().unwrap();
2558 if let Some(handle) = batch_norm_gpu_forward(
2559 input,
2560 self.weight.as_ref().map(|w| w.tensor()),
2561 self.bias.as_ref().map(|b| b.tensor()),
2562 &self.running_mean,
2563 &self.running_var,
2564 &self.num_batches_tracked,
2565 self.momentum,
2566 self.eps,
2567 channels,
2568 length,
2569 is_training,
2570 )? {
2571 return if is_grad_enabled() && input.requires_grad() {
2572 let grad_fn = Arc::new(BatchNorm1dBackward {
2573 input: input.clone(),
2574 x_hat: Tensor::from_storage(
2575 TensorStorage::cpu(Vec::new()),
2576 vec![0],
2577 false,
2578 )?,
2579 weight: self.weight.as_ref().map(|w| w.tensor().clone()),
2580 bias: self.bias.as_ref().map(|b| b.tensor().clone()),
2581 chan_var: Vec::new(),
2582 eps: self.eps,
2583 affine: self.affine,
2584 is_training,
2585 running_mean: self.running_mean.lock().unwrap().clone(),
2586 running_var: self.running_var.lock().unwrap().clone(),
2587 });
2588 Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
2589 } else {
2590 Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
2591 };
2592 }
2593 }
2594 return Err(FerrotorchError::NotImplementedOnCuda {
2595 op: "BatchNorm1d::forward",
2596 });
2597 }
2598 let input_data = input.data()?;
2599 let eps_t = T::from(self.eps).unwrap();
2600
2601 let weight_data = self.weight.as_ref().map(|w| w.tensor().data().unwrap());
2602 let bias_data = self.bias.as_ref().map(|b| b.tensor().data().unwrap());
2603
2604 let is_training = *self.training.lock().unwrap();
2605
2606 let mut chan_mean = vec![zero::<T>(); channels];
2607 let mut chan_var = vec![zero::<T>(); channels];
2608
2609 if is_training {
2610 let count = batch * length;
2611 let count_t = T::from(count).unwrap();
2612
2613 for c in 0..channels {
2614 let mut s = zero::<T>();
2615 for b in 0..batch {
2616 let base = b * channels * length + c * length;
2617 for l in 0..length {
2618 s += input_data[base + l];
2619 }
2620 }
2621 chan_mean[c] = s / count_t;
2622
2623 let mut var_sum = zero::<T>();
2624 for b in 0..batch {
2625 let base = b * channels * length + c * length;
2626 for l in 0..length {
2627 let d = input_data[base + l] - chan_mean[c];
2628 var_sum += d * d;
2629 }
2630 }
2631 chan_var[c] = var_sum / count_t;
2632 }
2633
2634 {
2636 let mut rm = self.running_mean.lock().unwrap();
2637 let mut rv = self.running_var.lock().unwrap();
2638 let mut nbt = self.num_batches_tracked.lock().unwrap();
2639 *nbt += 1;
2640
2641 let mom = self.momentum;
2642 let bessel = if count > 1 {
2643 count as f64 / (count as f64 - 1.0)
2644 } else {
2645 1.0
2646 };
2647
2648 for c in 0..channels {
2649 let batch_mean_f64 = chan_mean[c].to_f64().unwrap();
2650 let batch_var_f64 = chan_var[c].to_f64().unwrap();
2651
2652 rm[c] = (1.0 - mom) * rm[c] + mom * batch_mean_f64;
2653 rv[c] = (1.0 - mom) * rv[c] + mom * batch_var_f64 * bessel;
2654 }
2655 }
2656 } else {
2657 let rm = self.running_mean.lock().unwrap();
2658 let rv = self.running_var.lock().unwrap();
2659
2660 for c in 0..channels {
2661 chan_mean[c] = T::from(rm[c]).unwrap();
2662 chan_var[c] = T::from(rv[c]).unwrap();
2663 }
2664 }
2665
2666 let mut output = vec![zero::<T>(); input.numel()];
2667
2668 let mut inv_std = vec![zero::<T>(); channels];
2669 let need_x_hat = is_grad_enabled() && input.requires_grad();
2670 let mut x_hat_data = if need_x_hat {
2671 Vec::with_capacity(input.numel())
2672 } else {
2673 Vec::new()
2674 };
2675
2676 for c in 0..channels {
2677 inv_std[c] = (chan_var[c] + eps_t).sqrt().recip();
2678 }
2679
2680 for b in 0..batch {
2681 for c in 0..channels {
2682 let base = b * channels * length + c * length;
2683 for l in 0..length {
2684 let idx = base + l;
2685 let normed = (input_data[idx] - chan_mean[c]) * inv_std[c];
2686
2687 if need_x_hat {
2688 x_hat_data.push(normed);
2689 }
2690
2691 if self.affine {
2692 let w = weight_data.as_ref().unwrap();
2693 let bi = bias_data.as_ref().unwrap();
2694 output[idx] = normed * w[c] + bi[c];
2695 } else {
2696 output[idx] = normed;
2697 }
2698 }
2699 }
2700 }
2701
2702 let result = Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?;
2703
2704 if is_grad_enabled() && input.requires_grad() {
2705 let weight_tensor = self.weight.as_ref().map(|w| w.tensor().clone());
2706 let bias_tensor = self.bias.as_ref().map(|b| b.tensor().clone());
2707
2708 let grad_fn = Arc::new(BatchNorm1dBackward {
2709 input: input.clone(),
2710 x_hat: Tensor::from_storage(TensorStorage::cpu(x_hat_data), shape.to_vec(), false)?,
2711 weight: weight_tensor,
2712 bias: bias_tensor,
2713 chan_var: chan_var.iter().map(|v| v.to_f64().unwrap()).collect(),
2714 eps: self.eps,
2715 affine: self.affine,
2716 is_training,
2717 running_mean: self.running_mean.lock().unwrap().clone(),
2718 running_var: self.running_var.lock().unwrap().clone(),
2719 });
2720
2721 Tensor::from_operation(
2722 TensorStorage::cpu(result.data()?.to_vec()),
2723 result.shape().to_vec(),
2724 grad_fn,
2725 )
2726 } else {
2727 Ok(result)
2728 }
2729 }
2730
2731 fn parameters(&self) -> Vec<&Parameter<T>> {
2732 match (&self.weight, &self.bias) {
2733 (Some(w), Some(b)) => vec![w, b],
2734 _ => vec![],
2735 }
2736 }
2737
2738 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
2739 match (&mut self.weight, &mut self.bias) {
2740 (Some(w), Some(b)) => vec![w, b],
2741 _ => vec![],
2742 }
2743 }
2744
2745 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
2746 match (&self.weight, &self.bias) {
2747 (Some(w), Some(b)) => vec![("weight".to_string(), w), ("bias".to_string(), b)],
2748 _ => vec![],
2749 }
2750 }
2751
2752 fn train(&mut self) {
2753 *self.training.lock().unwrap() = true;
2754 }
2755
2756 fn eval(&mut self) {
2757 *self.training.lock().unwrap() = false;
2758 }
2759
2760 fn is_training(&self) -> bool {
2761 *self.training.lock().unwrap()
2762 }
2763
2764 fn as_any(&self) -> Option<&dyn std::any::Any> {
2767 Some(self)
2768 }
2769}
2770
/// Autograd node for `BatchNorm1d`.
#[derive(Debug)]
struct BatchNorm1dBackward<T: Float> {
    /// Input saved from the forward pass (`[N, C]` or `[N, C, L]`).
    input: Tensor<T>,
    /// Normalized input `(x - mean) * inv_std` saved by the CPU forward. It is an empty
    /// placeholder when the forward ran on the GPU.
    x_hat: Tensor<T>,
    /// Per-channel scale, if the layer has one.
    weight: Option<Tensor<T>>,
    /// Per-channel shift, if the layer has one.
    bias: Option<Tensor<T>>,
    /// Variance used to normalize each channel in the forward: the biased batch variance
    /// in training mode, the running variance in eval mode. It is empty on the GPU path.
    chan_var: Vec<f64>,
    /// Stabilizer added to the variance.
    eps: f64,
    /// Whether `weight`/`bias` were applied in the forward.
    affine: bool,
    /// Whether the forward normalized with batch statistics (training) or running ones.
    is_training: bool,
    /// Running mean snapshot taken after the forward; passed to the GPU backward.
    running_mean: Vec<f64>,
    /// Running variance snapshot taken after the forward; passed to the GPU backward.
    running_var: Vec<f64>,
}
2795
2796impl<T: Float> GradFn<T> for BatchNorm1dBackward<T> {
2797 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2798 let shape = self.input.shape();
2799 let ndim = shape.len();
2800 let batch = shape[0];
2801 let channels = shape[1];
2802 let length = if ndim == 3 { shape[2] } else { 1 };
2803 let count = batch * length;
2804 let count_t = T::from(count).unwrap();
2805
2806 if self.input.is_cuda() && is_f32::<T>() {
2808 let weight_dev;
2809 let weight_buf = match self.weight.as_ref() {
2810 Some(w) => w,
2811 None => {
2812 weight_dev = ferrotorch_core::creation::ones::<T>(&[channels])?
2813 .to(self.input.device())?;
2814 &weight_dev
2815 }
2816 };
2817 if let Some(grads) = batch_norm_gpu_backward(
2818 &self.input,
2819 grad_output,
2820 weight_buf,
2821 &self.running_mean,
2822 &self.running_var,
2823 batch,
2824 channels,
2825 length,
2826 self.eps,
2827 self.is_training,
2828 self.affine,
2829 self.weight.as_ref().is_some_and(|w| w.requires_grad()),
2830 self.bias.as_ref().is_some_and(|b| b.requires_grad()),
2831 )? {
2832 return Ok(grads);
2833 }
2834 return Err(FerrotorchError::NotImplementedOnCuda {
2835 op: "BatchNorm1dBackward",
2836 });
2837 }
2838 if self.input.is_cuda() {
2839 return Err(FerrotorchError::NotImplementedOnCuda {
2840 op: "BatchNorm1dBackward",
2841 });
2842 }
2843 let go_data = grad_output.data()?;
2844 let x_hat_data = self.x_hat.data()?;
2845
2846 let weight_data = self.weight.as_ref().map(|w| w.data().unwrap().to_vec());
2847
2848 let mut grad_input = vec![zero::<T>(); self.input.numel()];
2849 let mut grad_weight = vec![zero::<T>(); channels];
2850 let mut grad_bias = vec![zero::<T>(); channels];
2851
2852 for c in 0..channels {
2853 let var_f64 = self.chan_var[c];
2854 let inv_std = T::from(1.0 / (var_f64 + self.eps).sqrt()).unwrap();
2855
2856 let mut dl_dx_hat_sum = zero::<T>();
2857 let mut dl_dx_hat_x_hat_sum = zero::<T>();
2858
2859 for b in 0..batch {
2860 let base = b * channels * length + c * length;
2861 for l in 0..length {
2862 let idx = base + l;
2863 let x_h = x_hat_data[idx];
2864 let go = go_data[idx];
2865
2866 let dl_dx_hat = if self.affine {
2867 go * weight_data.as_ref().unwrap()[c]
2868 } else {
2869 go
2870 };
2871
2872 dl_dx_hat_sum += dl_dx_hat;
2873 dl_dx_hat_x_hat_sum += dl_dx_hat * x_h;
2874
2875 if self.affine {
2876 grad_weight[c] += go * x_h;
2877 grad_bias[c] += go;
2878 }
2879 }
2880 }
2881
2882 let dl_dx_hat_mean = dl_dx_hat_sum / count_t;
2883 let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / count_t;
2884
2885 for b in 0..batch {
2886 let base = b * channels * length + c * length;
2887 for l in 0..length {
2888 let idx = base + l;
2889 let x_h = x_hat_data[idx];
2890 let go = go_data[idx];
2891
2892 let dl_dx_hat = if self.affine {
2893 go * weight_data.as_ref().unwrap()[c]
2894 } else {
2895 go
2896 };
2897
2898 grad_input[idx] =
2899 inv_std * (dl_dx_hat - dl_dx_hat_mean - x_h * dl_dx_hat_x_hat_mean);
2900 }
2901 }
2902 }
2903
2904 let grad_input_tensor = Tensor::from_storage(
2905 TensorStorage::cpu(grad_input),
2906 self.input.shape().to_vec(),
2907 false,
2908 )?;
2909
2910 let grad_weight_out = if self.affine {
2911 if let Some(ref w) = self.weight {
2912 if w.requires_grad() {
2913 Some(Tensor::from_storage(
2914 TensorStorage::cpu(grad_weight),
2915 vec![channels],
2916 false,
2917 )?)
2918 } else {
2919 None
2920 }
2921 } else {
2922 None
2923 }
2924 } else {
2925 None
2926 };
2927
2928 let grad_bias_out = if self.affine {
2929 if let Some(ref b) = self.bias {
2930 if b.requires_grad() {
2931 Some(Tensor::from_storage(
2932 TensorStorage::cpu(grad_bias),
2933 vec![channels],
2934 false,
2935 )?)
2936 } else {
2937 None
2938 }
2939 } else {
2940 None
2941 }
2942 } else {
2943 None
2944 };
2945
2946 if self.affine {
2952 Ok(vec![
2953 Some(grad_input_tensor),
2954 grad_weight_out,
2955 grad_bias_out,
2956 ])
2957 } else {
2958 Ok(vec![Some(grad_input_tensor)])
2959 }
2960 }
2961
2962 fn inputs(&self) -> Vec<&Tensor<T>> {
2963 let mut v: Vec<&Tensor<T>> = vec![&self.input];
2964 if let Some(ref w) = self.weight {
2965 v.push(w);
2966 }
2967 if let Some(ref b) = self.bias {
2968 v.push(b);
2969 }
2970 v
2971 }
2972
2973 fn name(&self) -> &'static str {
2974 "BatchNorm1dBackward"
2975 }
2976}
2977
/// Batch normalization over 5-D `[B, C, D, H, W]` inputs with per-channel statistics.
///
/// The running statistics, batch counter and training flag sit behind `Mutex`es
/// because `forward` takes `&self` yet updates them.
pub struct BatchNorm3d<T: Float> {
    /// Expected channel count `C`; `forward` rejects inputs whose dim 1 differs.
    pub num_features: usize,
    /// Added to the variance before the square root.
    pub eps: f64,
    /// Running-stat blend factor: `new = (1 - momentum) * old + momentum * batch_stat`.
    pub momentum: f64,
    /// Whether the per-channel `weight`/`bias` are applied after normalizing.
    pub affine: bool,
    /// Per-channel scale; `new` sets it to `Some(ones)` iff `affine`.
    pub weight: Option<Parameter<T>>,
    /// Per-channel shift; `new` sets it to `Some(zeros)` iff `affine`.
    pub bias: Option<Parameter<T>>,
    /// Running per-channel mean (initialized to 0), used in eval mode.
    running_mean: Mutex<Vec<f64>>,
    /// Running per-channel variance (initialized to 1, updated with the
    /// Bessel-corrected batch variance), used in eval mode.
    running_var: Mutex<Vec<f64>>,
    /// Incremented by training-mode forward passes; settable via
    /// `set_num_batches_tracked`.
    num_batches_tracked: Mutex<usize>,
    /// `true` after `train()` (and initially), `false` after `eval()`.
    training: Mutex<bool>,
}
3023
3024impl<T: Float> std::fmt::Debug for BatchNorm3d<T> {
3025 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3026 f.debug_struct("BatchNorm3d")
3027 .field("num_features", &self.num_features)
3028 .field("eps", &self.eps)
3029 .field("momentum", &self.momentum)
3030 .field("affine", &self.affine)
3031 .field("weight", &self.weight)
3032 .field("bias", &self.bias)
3033 .field("training", &self.training)
3034 .finish()
3035 }
3036}
3037
3038impl<T: Float> BatchNorm3d<T> {
3039 pub fn new(
3048 num_features: usize,
3049 eps: f64,
3050 momentum: f64,
3051 affine: bool,
3052 ) -> FerrotorchResult<Self> {
3053 if num_features == 0 {
3054 return Err(FerrotorchError::InvalidArgument {
3055 message: "BatchNorm3d: num_features must be positive".into(),
3056 });
3057 }
3058
3059 let weight = if affine {
3060 Some(Parameter::ones(&[num_features])?)
3061 } else {
3062 None
3063 };
3064
3065 let bias = if affine {
3066 Some(Parameter::zeros(&[num_features])?)
3067 } else {
3068 None
3069 };
3070
3071 Ok(Self {
3072 num_features,
3073 eps,
3074 momentum,
3075 affine,
3076 weight,
3077 bias,
3078 running_mean: Mutex::new(vec![0.0; num_features]),
3079 running_var: Mutex::new(vec![1.0; num_features]),
3080 num_batches_tracked: Mutex::new(0),
3081 training: Mutex::new(true),
3082 })
3083 }
3084
3085 pub fn running_mean(&self) -> Vec<f64> {
3087 self.running_mean.lock().unwrap().clone()
3088 }
3089
3090 pub fn running_var(&self) -> Vec<f64> {
3092 self.running_var.lock().unwrap().clone()
3093 }
3094
3095 pub fn num_batches_tracked(&self) -> usize {
3097 *self.num_batches_tracked.lock().unwrap()
3098 }
3099
3100 pub fn set_running_mean(&self, value: &[T]) -> FerrotorchResult<()> {
3106 if value.len() != self.num_features {
3107 return Err(FerrotorchError::ShapeMismatch {
3108 message: format!(
3109 "BatchNorm3d::set_running_mean: expected slice of length \
3110 num_features={}, got {}",
3111 self.num_features,
3112 value.len()
3113 ),
3114 });
3115 }
3116 for (i, x) in value.iter().enumerate() {
3117 if !num_traits::Float::is_finite(*x) {
3118 return Err(FerrotorchError::InvalidArgument {
3119 message: format!(
3120 "BatchNorm3d::set_running_mean: non-finite value at \
3121 index {i} (running_mean must be finite)"
3122 ),
3123 });
3124 }
3125 }
3126 let mut rm = self.running_mean.lock().unwrap();
3127 for (slot, x) in rm.iter_mut().zip(value.iter()) {
3128 *slot = x.to_f64().unwrap();
3129 }
3130 Ok(())
3131 }
3132
3133 pub fn set_running_var(&self, value: &[T]) -> FerrotorchResult<()> {
3140 if value.len() != self.num_features {
3141 return Err(FerrotorchError::ShapeMismatch {
3142 message: format!(
3143 "BatchNorm3d::set_running_var: expected slice of length \
3144 num_features={}, got {}",
3145 self.num_features,
3146 value.len()
3147 ),
3148 });
3149 }
3150 let zero_t = zero::<T>();
3151 for (i, x) in value.iter().enumerate() {
3152 if !num_traits::Float::is_finite(*x) {
3153 return Err(FerrotorchError::InvalidArgument {
3154 message: format!(
3155 "BatchNorm3d::set_running_var: non-finite value at \
3156 index {i} (running_var must be finite)"
3157 ),
3158 });
3159 }
3160 if *x < zero_t {
3161 return Err(FerrotorchError::InvalidArgument {
3162 message: format!(
3163 "BatchNorm3d::set_running_var: negative value {} at \
3164 index {i} (running_var must be non-negative)",
3165 x.to_f64().unwrap()
3166 ),
3167 });
3168 }
3169 }
3170 let mut rv = self.running_var.lock().unwrap();
3171 for (slot, x) in rv.iter_mut().zip(value.iter()) {
3172 *slot = x.to_f64().unwrap();
3173 }
3174 Ok(())
3175 }
3176
3177 pub fn set_num_batches_tracked(&self, value: usize) -> FerrotorchResult<()> {
3181 let mut nbt = self.num_batches_tracked.lock().unwrap();
3182 *nbt = value;
3183 Ok(())
3184 }
3185}
3186
3187impl<T: Float> Module<T> for BatchNorm3d<T> {
3188 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3189 let shape = input.shape().to_vec();
3190 if shape.len() != 5 {
3191 return Err(FerrotorchError::ShapeMismatch {
3192 message: format!(
3193 "BatchNorm3d: expected 5D input [B, C, D, H, W], got {:?}",
3194 shape
3195 ),
3196 });
3197 }
3198
3199 let batch = shape[0];
3200 let channels = shape[1];
3201 let depth = shape[2];
3202 let height = shape[3];
3203 let width = shape[4];
3204 let spatial = depth * height * width;
3205
3206 if channels != self.num_features {
3207 return Err(FerrotorchError::ShapeMismatch {
3208 message: format!(
3209 "BatchNorm3d: expected {} channels, got {}",
3210 self.num_features, channels
3211 ),
3212 });
3213 }
3214
3215 if batch == 0 {
3216 return Ok(input.clone());
3217 }
3218
3219 if *self.training.lock().unwrap() && batch * spatial <= 1 {
3225 return Err(FerrotorchError::InvalidArgument {
3226 message: format!(
3227 "Expected more than 1 value per channel when training, got input size {:?}",
3228 shape
3229 ),
3230 });
3231 }
3232
3233 if input.is_cuda() {
3235 if is_f32::<T>() {
3236 let is_training = *self.training.lock().unwrap();
3237 if let Some(handle) = batch_norm_gpu_forward(
3238 input,
3239 self.weight.as_ref().map(|w| w.tensor()),
3240 self.bias.as_ref().map(|b| b.tensor()),
3241 &self.running_mean,
3242 &self.running_var,
3243 &self.num_batches_tracked,
3244 self.momentum,
3245 self.eps,
3246 channels,
3247 spatial,
3248 is_training,
3249 )? {
3250 return if is_grad_enabled() && input.requires_grad() {
3251 let grad_fn = Arc::new(BatchNorm3dBackward {
3252 input: input.clone(),
3253 x_hat: Tensor::from_storage(
3254 TensorStorage::cpu(Vec::new()),
3255 vec![0],
3256 false,
3257 )?,
3258 weight: self.weight.as_ref().map(|w| w.tensor().clone()),
3259 bias: self.bias.as_ref().map(|b| b.tensor().clone()),
3260 chan_var: Vec::new(),
3261 eps: self.eps,
3262 affine: self.affine,
3263 is_training,
3264 running_mean: self.running_mean.lock().unwrap().clone(),
3265 running_var: self.running_var.lock().unwrap().clone(),
3266 });
3267 Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
3268 } else {
3269 Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
3270 };
3271 }
3272 }
3273 return Err(FerrotorchError::NotImplementedOnCuda {
3274 op: "BatchNorm3d::forward",
3275 });
3276 }
3277 let input_data = input.data()?;
3278 let eps_t = T::from(self.eps).unwrap();
3279
3280 let weight_data = self.weight.as_ref().map(|w| w.tensor().data().unwrap());
3281 let bias_data = self.bias.as_ref().map(|b| b.tensor().data().unwrap());
3282
3283 let is_training = *self.training.lock().unwrap();
3284
3285 let mut chan_mean = vec![zero::<T>(); channels];
3286 let mut chan_var = vec![zero::<T>(); channels];
3287
3288 if is_training {
3289 let count = batch * spatial;
3290 let count_t = T::from(count).unwrap();
3291
3292 for c in 0..channels {
3293 let mut sum = zero::<T>();
3294 for b in 0..batch {
3295 let base = b * channels * spatial + c * spatial;
3296 for s in 0..spatial {
3297 sum += input_data[base + s];
3298 }
3299 }
3300 chan_mean[c] = sum / count_t;
3301
3302 let mut var_sum = zero::<T>();
3303 for b in 0..batch {
3304 let base = b * channels * spatial + c * spatial;
3305 for s in 0..spatial {
3306 let d = input_data[base + s] - chan_mean[c];
3307 var_sum += d * d;
3308 }
3309 }
3310 chan_var[c] = var_sum / count_t;
3311 }
3312
3313 {
3315 let mut rm = self.running_mean.lock().unwrap();
3316 let mut rv = self.running_var.lock().unwrap();
3317 let mut nbt = self.num_batches_tracked.lock().unwrap();
3318 *nbt += 1;
3319
3320 let mom = self.momentum;
3321 let bessel = if count > 1 {
3322 count as f64 / (count as f64 - 1.0)
3323 } else {
3324 1.0
3325 };
3326
3327 for c in 0..channels {
3328 let batch_mean_f64 = chan_mean[c].to_f64().unwrap();
3329 let batch_var_f64 = chan_var[c].to_f64().unwrap();
3330
3331 rm[c] = (1.0 - mom) * rm[c] + mom * batch_mean_f64;
3332 rv[c] = (1.0 - mom) * rv[c] + mom * batch_var_f64 * bessel;
3333 }
3334 }
3335 } else {
3336 let rm = self.running_mean.lock().unwrap();
3337 let rv = self.running_var.lock().unwrap();
3338
3339 for c in 0..channels {
3340 chan_mean[c] = T::from(rm[c]).unwrap();
3341 chan_var[c] = T::from(rv[c]).unwrap();
3342 }
3343 }
3344
3345 let mut output = vec![zero::<T>(); input.numel()];
3346
3347 let mut inv_std = vec![zero::<T>(); channels];
3348 let need_x_hat = is_grad_enabled() && input.requires_grad();
3349 let mut x_hat_data = if need_x_hat {
3350 Vec::with_capacity(input.numel())
3351 } else {
3352 Vec::new()
3353 };
3354
3355 for c in 0..channels {
3356 inv_std[c] = (chan_var[c] + eps_t).sqrt().recip();
3357 }
3358
3359 for b in 0..batch {
3360 for c in 0..channels {
3361 let base = b * channels * spatial + c * spatial;
3362 for s in 0..spatial {
3363 let idx = base + s;
3364 let normed = (input_data[idx] - chan_mean[c]) * inv_std[c];
3365
3366 if need_x_hat {
3367 x_hat_data.push(normed);
3368 }
3369
3370 if self.affine {
3371 let w = weight_data.as_ref().unwrap();
3372 let bi = bias_data.as_ref().unwrap();
3373 output[idx] = normed * w[c] + bi[c];
3374 } else {
3375 output[idx] = normed;
3376 }
3377 }
3378 }
3379 }
3380
3381 let result = Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?;
3382
3383 if is_grad_enabled() && input.requires_grad() {
3384 let weight_tensor = self.weight.as_ref().map(|w| w.tensor().clone());
3385 let bias_tensor = self.bias.as_ref().map(|b| b.tensor().clone());
3386
3387 let grad_fn = Arc::new(BatchNorm3dBackward {
3388 input: input.clone(),
3389 x_hat: Tensor::from_storage(TensorStorage::cpu(x_hat_data), shape.to_vec(), false)?,
3390 weight: weight_tensor,
3391 bias: bias_tensor,
3392 chan_var: chan_var.iter().map(|v| v.to_f64().unwrap()).collect(),
3393 eps: self.eps,
3394 affine: self.affine,
3395 is_training,
3396 running_mean: self.running_mean.lock().unwrap().clone(),
3397 running_var: self.running_var.lock().unwrap().clone(),
3398 });
3399
3400 Tensor::from_operation(
3401 TensorStorage::cpu(result.data()?.to_vec()),
3402 result.shape().to_vec(),
3403 grad_fn,
3404 )
3405 } else {
3406 Ok(result)
3407 }
3408 }
3409
3410 fn parameters(&self) -> Vec<&Parameter<T>> {
3411 match (&self.weight, &self.bias) {
3412 (Some(w), Some(b)) => vec![w, b],
3413 _ => vec![],
3414 }
3415 }
3416
3417 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
3418 match (&mut self.weight, &mut self.bias) {
3419 (Some(w), Some(b)) => vec![w, b],
3420 _ => vec![],
3421 }
3422 }
3423
3424 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
3425 match (&self.weight, &self.bias) {
3426 (Some(w), Some(b)) => vec![("weight".to_string(), w), ("bias".to_string(), b)],
3427 _ => vec![],
3428 }
3429 }
3430
3431 fn train(&mut self) {
3432 *self.training.lock().unwrap() = true;
3433 }
3434
3435 fn eval(&mut self) {
3436 *self.training.lock().unwrap() = false;
3437 }
3438
3439 fn is_training(&self) -> bool {
3440 *self.training.lock().unwrap()
3441 }
3442
3443 fn as_any(&self) -> Option<&dyn std::any::Any> {
3446 Some(self)
3447 }
3448}
3449
/// Autograd node recorded by `BatchNorm3d::forward`.
///
/// On the CPU path `x_hat` holds the normalized input and `chan_var` the
/// per-channel variance used for it. On the CUDA path both are empty
/// placeholders; the saved input and running statistics are passed to the GPU
/// backward instead.
#[derive(Debug)]
struct BatchNorm3dBackward<T: Float> {
    /// The forward input (its shape drives the backward loops).
    input: Tensor<T>,
    /// Normalized input `(x - mean) * inv_std`, same layout as `input` (CPU path).
    x_hat: Tensor<T>,
    /// Clone of the affine weight, `None` for a non-affine layer.
    weight: Option<Tensor<T>>,
    /// Clone of the affine bias, `None` for a non-affine layer.
    bias: Option<Tensor<T>>,
    /// Per-channel variance used in forward (batch variance when training,
    /// running variance in eval mode), as `f64`.
    chan_var: Vec<f64>,
    eps: f64,
    affine: bool,
    /// Whether the forward pass normalized with batch statistics.
    is_training: bool,
    /// Snapshot of the running mean taken after the forward pass.
    running_mean: Vec<f64>,
    /// Snapshot of the running variance taken after the forward pass.
    running_var: Vec<f64>,
}
3473
3474impl<T: Float> GradFn<T> for BatchNorm3dBackward<T> {
3475 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
3476 let shape = self.input.shape();
3477 let batch = shape[0];
3478 let channels = shape[1];
3479 let spatial: usize = shape[2..].iter().product();
3480 let count = batch * spatial;
3481 let count_t = T::from(count).unwrap();
3482
3483 if self.input.is_cuda() && is_f32::<T>() {
3485 let weight_dev;
3486 let weight_buf = match self.weight.as_ref() {
3487 Some(w) => w,
3488 None => {
3489 weight_dev = ferrotorch_core::creation::ones::<T>(&[channels])?
3490 .to(self.input.device())?;
3491 &weight_dev
3492 }
3493 };
3494 if let Some(grads) = batch_norm_gpu_backward(
3495 &self.input,
3496 grad_output,
3497 weight_buf,
3498 &self.running_mean,
3499 &self.running_var,
3500 batch,
3501 channels,
3502 spatial,
3503 self.eps,
3504 self.is_training,
3505 self.affine,
3506 self.weight.as_ref().is_some_and(|w| w.requires_grad()),
3507 self.bias.as_ref().is_some_and(|b| b.requires_grad()),
3508 )? {
3509 return Ok(grads);
3510 }
3511 return Err(FerrotorchError::NotImplementedOnCuda {
3512 op: "BatchNorm3dBackward",
3513 });
3514 }
3515 if self.input.is_cuda() {
3516 return Err(FerrotorchError::NotImplementedOnCuda {
3517 op: "BatchNorm3dBackward",
3518 });
3519 }
3520 let go_data = grad_output.data()?;
3521 let x_hat_data = self.x_hat.data()?;
3522
3523 let weight_data = self.weight.as_ref().map(|w| w.data().unwrap().to_vec());
3524
3525 let mut grad_input = vec![zero::<T>(); self.input.numel()];
3526 let mut grad_weight = vec![zero::<T>(); channels];
3527 let mut grad_bias = vec![zero::<T>(); channels];
3528
3529 for c in 0..channels {
3530 let var_f64 = self.chan_var[c];
3531 let inv_std = T::from(1.0 / (var_f64 + self.eps).sqrt()).unwrap();
3532
3533 let mut dl_dx_hat_sum = zero::<T>();
3534 let mut dl_dx_hat_x_hat_sum = zero::<T>();
3535
3536 for b in 0..batch {
3537 let base = b * channels * spatial + c * spatial;
3538 for s in 0..spatial {
3539 let idx = base + s;
3540 let x_h = x_hat_data[idx];
3541 let go = go_data[idx];
3542
3543 let dl_dx_hat = if self.affine {
3544 go * weight_data.as_ref().unwrap()[c]
3545 } else {
3546 go
3547 };
3548
3549 dl_dx_hat_sum += dl_dx_hat;
3550 dl_dx_hat_x_hat_sum += dl_dx_hat * x_h;
3551
3552 if self.affine {
3553 grad_weight[c] += go * x_h;
3554 grad_bias[c] += go;
3555 }
3556 }
3557 }
3558
3559 let dl_dx_hat_mean = dl_dx_hat_sum / count_t;
3560 let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / count_t;
3561
3562 for b in 0..batch {
3563 let base = b * channels * spatial + c * spatial;
3564 for s in 0..spatial {
3565 let idx = base + s;
3566 let x_h = x_hat_data[idx];
3567 let go = go_data[idx];
3568
3569 let dl_dx_hat = if self.affine {
3570 go * weight_data.as_ref().unwrap()[c]
3571 } else {
3572 go
3573 };
3574
3575 grad_input[idx] =
3576 inv_std * (dl_dx_hat - dl_dx_hat_mean - x_h * dl_dx_hat_x_hat_mean);
3577 }
3578 }
3579 }
3580
3581 let grad_input_tensor = Tensor::from_storage(
3582 TensorStorage::cpu(grad_input),
3583 self.input.shape().to_vec(),
3584 false,
3585 )?;
3586
3587 let grad_weight_out = if self.affine {
3588 if let Some(ref w) = self.weight {
3589 if w.requires_grad() {
3590 Some(Tensor::from_storage(
3591 TensorStorage::cpu(grad_weight),
3592 vec![channels],
3593 false,
3594 )?)
3595 } else {
3596 None
3597 }
3598 } else {
3599 None
3600 }
3601 } else {
3602 None
3603 };
3604
3605 let grad_bias_out = if self.affine {
3606 if let Some(ref b) = self.bias {
3607 if b.requires_grad() {
3608 Some(Tensor::from_storage(
3609 TensorStorage::cpu(grad_bias),
3610 vec![channels],
3611 false,
3612 )?)
3613 } else {
3614 None
3615 }
3616 } else {
3617 None
3618 }
3619 } else {
3620 None
3621 };
3622
3623 if self.affine {
3629 Ok(vec![
3630 Some(grad_input_tensor),
3631 grad_weight_out,
3632 grad_bias_out,
3633 ])
3634 } else {
3635 Ok(vec![Some(grad_input_tensor)])
3636 }
3637 }
3638
3639 fn inputs(&self) -> Vec<&Tensor<T>> {
3640 let mut v: Vec<&Tensor<T>> = vec![&self.input];
3641 if let Some(ref w) = self.weight {
3642 v.push(w);
3643 }
3644 if let Some(ref b) = self.bias {
3645 v.push(b);
3646 }
3647 v
3648 }
3649
3650 fn name(&self) -> &'static str {
3651 "BatchNorm3dBackward"
3652 }
3653}
3654
/// Local response normalization across channels for `[B, C, ...]` inputs.
///
/// Each element is divided by `(k + alpha / size * sum(x_j^2))^beta`. The sum
/// runs over a window of channels from `size / 2` before the current channel to
/// `size - size / 2 - 1` after it, clipped to the valid channel range.
#[derive(Debug, Clone)]
pub struct LocalResponseNorm {
    /// Number of channels in the normalization window (validated `> 0` by `new`).
    pub size: usize,
    /// Scale applied to the windowed mean of squares.
    pub alpha: f64,
    /// Exponent applied to the denominator.
    pub beta: f64,
    /// Additive constant in the denominator.
    pub k: f64,
    /// Toggled by `train`/`eval`; the forward pass does not read it.
    training: bool,
}
3689
3690impl LocalResponseNorm {
3691 pub fn new(size: usize, alpha: f64, beta: f64, k: f64) -> FerrotorchResult<Self> {
3700 if size == 0 {
3701 return Err(FerrotorchError::InvalidArgument {
3702 message: "LocalResponseNorm: size must be positive".into(),
3703 });
3704 }
3705 Ok(Self {
3706 size,
3707 alpha,
3708 beta,
3709 k,
3710 training: true,
3711 })
3712 }
3713
3714 pub fn default_params(size: usize) -> FerrotorchResult<Self> {
3716 Self::new(size, 1e-4, 0.75, 1.0)
3717 }
3718}
3719
3720impl<T: Float> Module<T> for LocalResponseNorm {
3721 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3722 let shape = input.shape().to_vec();
3723 if shape.len() < 3 {
3724 return Err(FerrotorchError::ShapeMismatch {
3725 message: format!(
3726 "LocalResponseNorm: expected at least 3D input [B, C, ...], got {:?}",
3727 shape
3728 ),
3729 });
3730 }
3731
3732 let batch = shape[0];
3733 let channels = shape[1];
3734 let spatial: usize = shape[2..].iter().product();
3735
3736 if input.is_cuda() && is_f32::<T>() {
3741 if let Some(backend) = gpu_backend() {
3742 let (out_h, denom_h) = backend.local_response_norm_f32(
3743 input.gpu_handle()?,
3744 batch,
3745 channels,
3746 spatial,
3747 self.size,
3748 self.alpha as f32,
3749 self.beta as f32,
3750 self.k as f32,
3751 )?;
3752 let denom_gpu =
3753 Tensor::from_storage(TensorStorage::gpu(denom_h), shape.clone(), false)?;
3754 return if is_grad_enabled() && input.requires_grad() {
3755 Tensor::from_operation(
3756 TensorStorage::gpu(out_h),
3757 shape,
3758 Arc::new(LocalResponseNormBackward {
3759 input: input.clone(),
3760 denom: Vec::new(),
3761 denom_gpu: Some(denom_gpu),
3762 size: self.size,
3763 alpha: self.alpha,
3764 beta: self.beta,
3765 }),
3766 )
3767 } else {
3768 Tensor::from_storage(TensorStorage::gpu(out_h), shape, false)
3769 };
3770 }
3771 return Err(FerrotorchError::NotImplementedOnCuda {
3772 op: "LocalResponseNorm::forward",
3773 });
3774 }
3775 if input.is_cuda() {
3776 return Err(FerrotorchError::NotImplementedOnCuda {
3777 op: "LocalResponseNorm::forward",
3778 });
3779 }
3780 let input_data = input.data()?;
3781 let alpha_t = T::from(self.alpha).unwrap();
3782 let beta_t = T::from(self.beta).unwrap();
3783 let k_t = T::from(self.k).unwrap();
3784 let size_t = T::from(self.size).unwrap();
3785 let half = self.size / 2;
3796 let upper = self.size - half; let mut output = vec![zero::<T>(); input.numel()];
3799
3800 let mut denom = vec![zero::<T>(); input.numel()];
3803
3804 for b in 0..batch {
3809 for c in 0..channels {
3810 let c_start = c.saturating_sub(half);
3811 let c_end = (c + upper).min(channels);
3812
3813 for s in 0..spatial {
3814 let mut sq_sum = zero::<T>();
3815 for j in c_start..c_end {
3816 let jidx = b * channels * spatial + j * spatial + s;
3817 sq_sum += input_data[jidx] * input_data[jidx];
3818 }
3819
3820 let idx = b * channels * spatial + c * spatial + s;
3821 let d = sq_sum / size_t * alpha_t + k_t;
3822 denom[idx] = d;
3823 output[idx] = input_data[idx] / d.powf(beta_t);
3824 }
3825 }
3826 }
3827
3828 let storage = TensorStorage::cpu(output);
3829
3830 if is_grad_enabled() && input.requires_grad() {
3831 Tensor::from_operation(
3832 storage,
3833 shape,
3834 Arc::new(LocalResponseNormBackward {
3835 input: input.clone(),
3836 denom,
3837 denom_gpu: None,
3838 size: self.size,
3839 alpha: self.alpha,
3840 beta: self.beta,
3841 }),
3842 )
3843 } else {
3844 Tensor::from_storage(storage, shape, false)
3845 }
3846 }
3847
3848 fn parameters(&self) -> Vec<&Parameter<T>> {
3849 vec![]
3850 }
3851
3852 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
3853 vec![]
3854 }
3855
3856 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
3857 vec![]
3858 }
3859
3860 fn train(&mut self) {
3861 self.training = true;
3862 }
3863
3864 fn eval(&mut self) {
3865 self.training = false;
3866 }
3867
3868 fn is_training(&self) -> bool {
3869 self.training
3870 }
3871}
3872
/// Autograd node recorded by `LocalResponseNorm::forward`.
///
/// The CPU forward fills `denom` and leaves `denom_gpu` as `None`. The CUDA
/// forward leaves `denom` empty and stores the device-side denominator in
/// `denom_gpu`.
#[derive(Debug)]
struct LocalResponseNormBackward<T: Float> {
    /// The forward input.
    input: Tensor<T>,
    /// Per-element `k + alpha / size * sum(x_j^2)`, same layout as `input` (CPU path).
    denom: Vec<T>,
    /// Device-resident denominator (CUDA path).
    denom_gpu: Option<Tensor<T>>,
    /// Window size used in forward.
    size: usize,
    alpha: f64,
    beta: f64,
}
3902
3903impl<T: Float> GradFn<T> for LocalResponseNormBackward<T> {
3904 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
3905 if !self.input.requires_grad() {
3906 return Ok(vec![None]);
3907 }
3908
3909 if self.input.is_cuda() && is_f32::<T>() {
3912 if let (Some(backend), Some(denom_gpu)) = (gpu_backend(), self.denom_gpu.as_ref()) {
3913 let shape = self.input.shape();
3914 let batch = shape[0];
3915 let channels = shape[1];
3916 let spatial: usize = shape[2..].iter().product();
3917 let gi_h = backend.local_response_norm_backward_f32(
3918 self.input.gpu_handle()?,
3919 grad_output.gpu_handle()?,
3920 denom_gpu.gpu_handle()?,
3921 batch,
3922 channels,
3923 spatial,
3924 self.size,
3925 self.alpha as f32,
3926 self.beta as f32,
3927 )?;
3928 let grad_input =
3929 Tensor::from_storage(TensorStorage::gpu(gi_h), shape.to_vec(), false)?;
3930 return Ok(vec![Some(grad_input)]);
3931 }
3932 return Err(FerrotorchError::NotImplementedOnCuda {
3933 op: "LocalResponseNormBackward",
3934 });
3935 }
3936 if self.input.is_cuda() {
3937 return Err(FerrotorchError::NotImplementedOnCuda {
3938 op: "LocalResponseNormBackward",
3939 });
3940 }
3941
3942 let shape = self.input.shape();
3943 let batch = shape[0];
3944 let channels = shape[1];
3945 let spatial: usize = shape[2..].iter().product();
3946
3947 let input_data = self.input.data()?;
3948 let go_data = grad_output.data()?;
3949 let alpha_t = T::from(self.alpha).unwrap();
3950 let beta_t = T::from(self.beta).unwrap();
3951 let size_t = T::from(self.size).unwrap();
3952 let two = T::from(2.0).unwrap();
3953 let half = self.size / 2;
3962 let upper = self.size - half; let mut grad_input = vec![zero::<T>(); self.input.numel()];
3965
3966 for b in 0..batch {
3967 for i_c in 0..channels {
3968 for s in 0..spatial {
3969 let i_idx = b * channels * spatial + i_c * spatial + s;
3970
3971 let term1 = self.denom[i_idx].powf(-beta_t) * go_data[i_idx];
3973
3974 let c_start = (i_c + 1).saturating_sub(upper);
3978 let c_end = (i_c + half + 1).min(channels);
3979
3980 let mut cross_sum = zero::<T>();
3981 for c in c_start..c_end {
3982 let c_idx = b * channels * spatial + c * spatial + s;
3983 cross_sum += go_data[c_idx]
3984 * input_data[c_idx]
3985 * self.denom[c_idx].powf(-beta_t - T::from(1.0).unwrap());
3986 }
3987
3988 grad_input[i_idx] =
3989 term1 - two * beta_t * alpha_t / size_t * input_data[i_idx] * cross_sum;
3990 }
3991 }
3992 }
3993
3994 let grad_tensor = Tensor::from_storage(
3995 TensorStorage::cpu(grad_input),
3996 self.input.shape().to_vec(),
3997 false,
3998 )?;
3999 Ok(vec![Some(grad_tensor)])
4000 }
4001
4002 fn inputs(&self) -> Vec<&Tensor<T>> {
4003 vec![&self.input]
4004 }
4005
4006 fn name(&self) -> &'static str {
4007 "LocalResponseNormBackward"
4008 }
4009}
4010
/// State shared by `InstanceNorm1d` (and presumably its 2-D/3-D siblings):
/// each `(sample, channel)` slice is normalized with its own statistics.
#[derive(Debug)]
struct InstanceNormInner<T: Float> {
    /// Expected channel count; `forward_impl` rejects inputs whose dim 1 differs.
    num_features: usize,
    /// Added to the variance before the square root.
    eps: f64,
    /// Whether the CPU forward applies `weight`/`bias` after normalizing.
    affine: bool,
    /// Per-channel scale, always allocated (ones) even when not `affine`.
    weight: Parameter<T>,
    /// Per-channel shift, always allocated (zeros) even when not `affine`.
    bias: Parameter<T>,
    /// Initialized to `true`; not read by `forward_impl`.
    training: bool,
}
4043
4044impl<T: Float> InstanceNormInner<T> {
4045 fn new(num_features: usize, eps: f64, affine: bool) -> FerrotorchResult<Self> {
4046 if num_features == 0 {
4047 return Err(FerrotorchError::InvalidArgument {
4048 message: "InstanceNorm: num_features must be positive".into(),
4049 });
4050 }
4051
4052 let weight = Parameter::ones(&[num_features])?;
4053 let bias = Parameter::zeros(&[num_features])?;
4054
4055 Ok(Self {
4056 num_features,
4057 eps,
4058 affine,
4059 weight,
4060 bias,
4061 training: true,
4062 })
4063 }
4064
4065 fn forward_impl(&self, input: &Tensor<T>, expected_ndim: usize) -> FerrotorchResult<Tensor<T>> {
4068 let label = match expected_ndim {
4069 3 => "InstanceNorm1d",
4070 4 => "InstanceNorm2d",
4071 _ => "InstanceNorm3d",
4072 };
4073 let shape = input.shape().to_vec();
4074
4075 if shape.len() != expected_ndim {
4076 return Err(FerrotorchError::ShapeMismatch {
4077 message: format!("{label}: expected {expected_ndim}D input, got {:?}", shape),
4078 });
4079 }
4080
4081 let batch = shape[0];
4082 let channels = shape[1];
4083 if channels != self.num_features {
4084 return Err(FerrotorchError::ShapeMismatch {
4085 message: format!(
4086 "{label}: expected {} channels, got {}",
4087 self.num_features, channels
4088 ),
4089 });
4090 }
4091
4092 let spatial: usize = shape[2..].iter().product();
4093 if spatial == 0 {
4094 return Ok(input.clone());
4095 }
4096
4097 if input.is_cuda() {
4105 if let Some(backend) = gpu_backend() {
4106 let eps_f32 = self.eps as f32;
4107 let handle = backend.group_norm_f32(
4108 input.gpu_handle()?,
4109 self.weight.tensor().gpu_handle()?,
4110 self.bias.tensor().gpu_handle()?,
4111 batch,
4112 channels,
4113 channels, spatial,
4115 eps_f32,
4116 )?;
4117 return if is_grad_enabled() && input.requires_grad() {
4118 let grad_fn = Arc::new(InstanceNormBackward {
4119 input: input.clone(),
4120 weight: self.weight.tensor().clone(),
4121 bias: self.bias.tensor().clone(),
4122 num_features: self.num_features,
4123 eps: self.eps,
4124 affine: self.affine,
4125 });
4126 Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
4127 } else {
4128 Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
4129 };
4130 }
4131 return Err(FerrotorchError::NotImplementedOnCuda {
4133 op: "InstanceNorm::forward",
4134 });
4135 }
4136 let input_data = input.data()?;
4137 let eps_t = T::from(self.eps).unwrap();
4138 let n_t = T::from(spatial).unwrap();
4139
4140 let weight_data = self.weight.tensor().data()?;
4141 let bias_data = self.bias.tensor().data()?;
4142
4143 let mut output = vec![zero::<T>(); input.numel()];
4144
4145 for b in 0..batch {
4146 for c in 0..channels {
4147 let base = b * channels * spatial + c * spatial;
4148 let slice = &input_data[base..base + spatial];
4149
4150 let mean = slice.iter().copied().fold(zero::<T>(), |a, x| a + x) / n_t;
4152 let var = slice.iter().copied().fold(zero::<T>(), |a, x| {
4153 let d = x - mean;
4154 a + d * d
4155 }) / n_t;
4156 let inv_std = (var + eps_t).sqrt().recip();
4157
4158 for s in 0..spatial {
4159 let idx = base + s;
4160 let normed = (input_data[idx] - mean) * inv_std;
4161 if self.affine {
4162 output[idx] = normed * weight_data[c] + bias_data[c];
4163 } else {
4164 output[idx] = normed;
4165 }
4166 }
4167 }
4168 }
4169
4170 let result = Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?;
4171
4172 if is_grad_enabled() && input.requires_grad() {
4173 let grad_fn = Arc::new(InstanceNormBackward {
4174 input: input.clone(),
4175 weight: self.weight.tensor().clone(),
4176 bias: self.bias.tensor().clone(),
4177 num_features: self.num_features,
4178 eps: self.eps,
4179 affine: self.affine,
4180 });
4181 Tensor::from_operation(
4182 TensorStorage::cpu(result.data()?.to_vec()),
4183 result.shape().to_vec(),
4184 grad_fn,
4185 )
4186 } else {
4187 Ok(result)
4188 }
4189 }
4190}
4191
/// Autograd node recorded by `InstanceNormInner::forward_impl`.
///
/// No normalized buffer is saved: the CPU backward recomputes each slice's mean
/// and variance from `input`.
#[derive(Debug)]
struct InstanceNormBackward<T: Float> {
    /// The forward input.
    input: Tensor<T>,
    /// Per-channel weight (always present; only used when `affine`).
    weight: Tensor<T>,
    /// Per-channel bias (always present; receives a gradient only when `affine`).
    bias: Tensor<T>,
    /// Length of the weight/bias gradients.
    num_features: usize,
    eps: f64,
    affine: bool,
}
4209
4210impl<T: Float> GradFn<T> for InstanceNormBackward<T> {
4211 fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
4212 let shape = self.input.shape();
4213 let batch = shape[0];
4214 let channels = shape[1];
4215 let spatial: usize = shape[2..].iter().product();
4216 let n_t = T::from(spatial).unwrap();
4217 let eps_t = T::from(self.eps).unwrap();
4218
4219 if self.input.is_cuda() && is_f32::<T>() {
4229 if let Some(grads) = instance_norm_gpu_backward(
4230 &self.input,
4231 grad_output,
4232 &self.weight,
4233 batch,
4234 channels,
4235 spatial,
4236 self.eps,
4237 self.affine,
4238 self.weight.requires_grad(),
4239 self.bias.requires_grad(),
4240 )? {
4241 return Ok(grads);
4242 }
4243 return Err(FerrotorchError::NotImplementedOnCuda {
4244 op: "InstanceNormBackward",
4245 });
4246 }
4247 if self.input.is_cuda() {
4248 return Err(FerrotorchError::NotImplementedOnCuda {
4249 op: "InstanceNormBackward",
4250 });
4251 }
4252 let input_data = self.input.data()?;
4253 let go_data = grad_output.data()?;
4254 let weight_data = self.weight.data()?;
4255
4256 let mut grad_input = vec![zero::<T>(); self.input.numel()];
4257 let mut grad_weight = vec![zero::<T>(); self.num_features];
4258 let mut grad_bias = vec![zero::<T>(); self.num_features];
4259
4260 for b in 0..batch {
4261 for c in 0..channels {
4262 let base = b * channels * spatial + c * spatial;
4263 let x_slice = &input_data[base..base + spatial];
4264 let go_slice = &go_data[base..base + spatial];
4265
4266 let mean = x_slice.iter().copied().fold(zero::<T>(), |a, x| a + x) / n_t;
4268 let var = x_slice.iter().copied().fold(zero::<T>(), |a, x| {
4269 let d = x - mean;
4270 a + d * d
4271 }) / n_t;
4272 let inv_std = (var + eps_t).sqrt().recip();
4273
4274 let mut dl_dx_hat_sum = zero::<T>();
4276 let mut dl_dx_hat_x_hat_sum = zero::<T>();
4277
4278 for s in 0..spatial {
4279 let x_hat = (x_slice[s] - mean) * inv_std;
4280 let dl_dx_hat = if self.affine {
4281 go_slice[s] * weight_data[c]
4282 } else {
4283 go_slice[s]
4284 };
4285 dl_dx_hat_sum += dl_dx_hat;
4286 dl_dx_hat_x_hat_sum += dl_dx_hat * x_hat;
4287
4288 if self.affine {
4289 grad_weight[c] += go_slice[s] * x_hat;
4290 grad_bias[c] += go_slice[s];
4291 }
4292 }
4293
4294 let dl_dx_hat_mean = dl_dx_hat_sum / n_t;
4295 let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / n_t;
4296
4297 for s in 0..spatial {
4298 let x_hat = (x_slice[s] - mean) * inv_std;
4299 let dl_dx_hat = if self.affine {
4300 go_slice[s] * weight_data[c]
4301 } else {
4302 go_slice[s]
4303 };
4304 grad_input[base + s] =
4305 inv_std * (dl_dx_hat - dl_dx_hat_mean - x_hat * dl_dx_hat_x_hat_mean);
4306 }
4307 }
4308 }
4309
4310 let grad_input_tensor = Tensor::from_storage(
4311 TensorStorage::cpu(grad_input),
4312 self.input.shape().to_vec(),
4313 false,
4314 )?;
4315
4316 let grad_weight_out = if self.affine && self.weight.requires_grad() {
4317 Some(Tensor::from_storage(
4318 TensorStorage::cpu(grad_weight),
4319 vec![self.num_features],
4320 false,
4321 )?)
4322 } else {
4323 None
4324 };
4325
4326 let grad_bias_out = if self.affine && self.bias.requires_grad() {
4327 Some(Tensor::from_storage(
4328 TensorStorage::cpu(grad_bias),
4329 vec![self.num_features],
4330 false,
4331 )?)
4332 } else {
4333 None
4334 };
4335
4336 Ok(vec![
4337 Some(grad_input_tensor),
4338 grad_weight_out,
4339 grad_bias_out,
4340 ])
4341 }
4342
4343 fn inputs(&self) -> Vec<&Tensor<T>> {
4344 vec![&self.input, &self.weight, &self.bias]
4345 }
4346
4347 fn name(&self) -> &'static str {
4348 "InstanceNormBackward"
4349 }
4350}
4351
4352#[derive(Debug)]
4363pub struct InstanceNorm1d<T: Float> {
4364 inner: InstanceNormInner<T>,
4365}
4366
4367impl<T: Float> InstanceNorm1d<T> {
4368 pub fn new(num_features: usize, eps: f64, affine: bool) -> FerrotorchResult<Self> {
4376 Ok(Self {
4377 inner: InstanceNormInner::new(num_features, eps, affine)?,
4378 })
4379 }
4380}
4381
4382impl<T: Float> Module<T> for InstanceNorm1d<T> {
4383 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
4384 self.inner.forward_impl(input, 3)
4385 }
4386
4387 fn parameters(&self) -> Vec<&Parameter<T>> {
4388 if self.inner.affine {
4389 vec![&self.inner.weight, &self.inner.bias]
4390 } else {
4391 vec![]
4392 }
4393 }
4394
4395 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
4396 if self.inner.affine {
4397 vec![&mut self.inner.weight, &mut self.inner.bias]
4398 } else {
4399 vec![]
4400 }
4401 }
4402
4403 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
4404 if self.inner.affine {
4405 vec![
4406 ("weight".to_string(), &self.inner.weight),
4407 ("bias".to_string(), &self.inner.bias),
4408 ]
4409 } else {
4410 vec![]
4411 }
4412 }
4413
4414 fn train(&mut self) {
4415 self.inner.training = true;
4416 }
4417
4418 fn eval(&mut self) {
4419 self.inner.training = false;
4420 }
4421
4422 fn is_training(&self) -> bool {
4423 self.inner.training
4424 }
4425}
4426
4427#[derive(Debug)]
4438pub struct InstanceNorm2d<T: Float> {
4439 inner: InstanceNormInner<T>,
4440}
4441
4442impl<T: Float> InstanceNorm2d<T> {
4443 pub fn new(num_features: usize, eps: f64, affine: bool) -> FerrotorchResult<Self> {
4451 Ok(Self {
4452 inner: InstanceNormInner::new(num_features, eps, affine)?,
4453 })
4454 }
4455}
4456
4457impl<T: Float> Module<T> for InstanceNorm2d<T> {
4458 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
4459 self.inner.forward_impl(input, 4)
4460 }
4461
4462 fn parameters(&self) -> Vec<&Parameter<T>> {
4463 if self.inner.affine {
4464 vec![&self.inner.weight, &self.inner.bias]
4465 } else {
4466 vec![]
4467 }
4468 }
4469
4470 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
4471 if self.inner.affine {
4472 vec![&mut self.inner.weight, &mut self.inner.bias]
4473 } else {
4474 vec![]
4475 }
4476 }
4477
4478 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
4479 if self.inner.affine {
4480 vec![
4481 ("weight".to_string(), &self.inner.weight),
4482 ("bias".to_string(), &self.inner.bias),
4483 ]
4484 } else {
4485 vec![]
4486 }
4487 }
4488
4489 fn train(&mut self) {
4490 self.inner.training = true;
4491 }
4492
4493 fn eval(&mut self) {
4494 self.inner.training = false;
4495 }
4496
4497 fn is_training(&self) -> bool {
4498 self.inner.training
4499 }
4500}
4501
4502#[derive(Debug)]
4513pub struct InstanceNorm3d<T: Float> {
4514 inner: InstanceNormInner<T>,
4515}
4516
4517impl<T: Float> InstanceNorm3d<T> {
4518 pub fn new(num_features: usize, eps: f64, affine: bool) -> FerrotorchResult<Self> {
4526 Ok(Self {
4527 inner: InstanceNormInner::new(num_features, eps, affine)?,
4528 })
4529 }
4530}
4531
4532impl<T: Float> Module<T> for InstanceNorm3d<T> {
4533 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
4534 self.inner.forward_impl(input, 5)
4535 }
4536
4537 fn parameters(&self) -> Vec<&Parameter<T>> {
4538 if self.inner.affine {
4539 vec![&self.inner.weight, &self.inner.bias]
4540 } else {
4541 vec![]
4542 }
4543 }
4544
4545 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
4546 if self.inner.affine {
4547 vec![&mut self.inner.weight, &mut self.inner.bias]
4548 } else {
4549 vec![]
4550 }
4551 }
4552
4553 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
4554 if self.inner.affine {
4555 vec![
4556 ("weight".to_string(), &self.inner.weight),
4557 ("bias".to_string(), &self.inner.bias),
4558 ]
4559 } else {
4560 vec![]
4561 }
4562 }
4563
4564 fn train(&mut self) {
4565 self.inner.training = true;
4566 }
4567
4568 fn eval(&mut self) {
4569 self.inner.training = false;
4570 }
4571
4572 fn is_training(&self) -> bool {
4573 self.inner.training
4574 }
4575}
4576
4577#[cfg(test)]
4582mod tests {
4583 use super::*;
4584 use ferrotorch_core::autograd::no_grad::no_grad;
4585
4586 fn leaf(data: &[f64], shape: &[usize], requires_grad: bool) -> Tensor<f64> {
4588 Tensor::from_storage(
4589 TensorStorage::cpu(data.to_vec()),
4590 shape.to_vec(),
4591 requires_grad,
4592 )
4593 .unwrap()
4594 }
4595
4596 #[test]
4601 fn test_layer_norm_parameter_shapes() {
4602 let ln = LayerNorm::<f32>::new(vec![8], 1e-5, true).unwrap();
4603 let params = ln.parameters();
4604 assert_eq!(params.len(), 2);
4605 assert_eq!(params[0].shape(), &[8]); assert_eq!(params[1].shape(), &[8]); }
4608
4609 #[test]
4610 fn test_layer_norm_no_affine_no_params() {
4611 let ln = LayerNorm::<f32>::new(vec![8], 1e-5, false).unwrap();
4612 assert_eq!(ln.parameters().len(), 0);
4613 }
4614
4615 #[test]
4616 fn test_layer_norm_forward_zero_mean_unit_var() {
4617 let data: Vec<f32> = vec![
4620 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, ];
4623 let input = Tensor::from_storage(TensorStorage::cpu(data), vec![2, 8], false).unwrap();
4624
4625 let ln = LayerNorm::<f32>::new(vec![8], 1e-5, true).unwrap();
4626 let output = ln.forward(&input).unwrap();
4627 let out_data = output.data().unwrap();
4628
4629 for row in 0..2 {
4630 let start = row * 8;
4631 let end = start + 8;
4632 let row_data = &out_data[start..end];
4633
4634 let mean: f32 = row_data.iter().sum::<f32>() / 8.0;
4635 let var: f32 = row_data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / 8.0;
4636
4637 assert!(mean.abs() < 1e-5, "row {row} mean = {mean}, expected ~0");
4638 assert!(
4639 (var - 1.0).abs() < 0.05,
4640 "row {row} var = {var}, expected ~1"
4641 );
4642 }
4643 }
4644
4645 #[test]
4646 fn test_layer_norm_forward_shape_preserved() {
4647 let input =
4648 Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 24]), vec![2, 3, 4], false)
4649 .unwrap();
4650
4651 let ln = LayerNorm::<f32>::new(vec![4], 1e-5, true).unwrap();
4652 let output = ln.forward(&input).unwrap();
4653 assert_eq!(output.shape(), &[2, 3, 4]);
4654 }
4655
4656 #[test]
4657 fn test_layer_norm_shape_mismatch() {
4658 let input =
4659 Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 12]), vec![3, 4], false)
4660 .unwrap();
4661
4662 let ln = LayerNorm::<f32>::new(vec![5], 1e-5, true).unwrap();
4663 assert!(ln.forward(&input).is_err());
4664 }
4665
4666 #[test]
4667 fn test_layer_norm_empty_normalized_shape() {
4668 assert!(LayerNorm::<f32>::new(vec![], 1e-5, true).is_err());
4669 }
4670
4671 #[test]
4672 fn test_layer_norm_has_grad_fn_when_input_requires_grad() {
4673 let input = Tensor::<f32>::from_storage(
4674 TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0]),
4675 vec![1, 4],
4676 true,
4677 )
4678 .unwrap();
4679
4680 let ln = LayerNorm::<f32>::new(vec![4], 1e-5, true).unwrap();
4681 let output = ln.forward(&input).unwrap();
4682 assert!(output.grad_fn().is_some());
4683 assert_eq!(output.grad_fn().unwrap().name(), "LayerNormBackward");
4684 }
4685
4686 #[test]
4687 fn test_layer_norm_no_grad_fn_in_no_grad_context() {
4688 let input = Tensor::<f32>::from_storage(
4689 TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0]),
4690 vec![1, 4],
4691 true,
4692 )
4693 .unwrap();
4694
4695 let ln = LayerNorm::<f32>::new(vec![4], 1e-5, true).unwrap();
4696 let output = no_grad(|| ln.forward(&input)).unwrap();
4697 assert!(output.grad_fn().is_none());
4698 }
4699
4700 #[test]
4701 fn test_layer_norm_backward_gradient_check() -> FerrotorchResult<()> {
4702 let h = 1e-7;
4705 let hidden = 4;
4706 let input_data = vec![1.0f64, -0.5, 2.0, 0.3];
4707
4708 let ln = LayerNorm::<f64>::new(vec![hidden], 1e-5, true)?;
4709
4710 let input = leaf(&input_data, &[1, hidden], true);
4712 let output = ln.forward(&input)?;
4713 let out_data = output.data()?.to_vec();
4714 let total: f64 = out_data.iter().sum();
4715
4716 let sum_gf = Arc::new(SumBackwardHelper {
4718 input: output.clone(),
4719 });
4720 let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf)?;
4721 loss.backward()?;
4722
4723 let analytic_grad = input.grad().unwrap().unwrap();
4724 let analytic = analytic_grad.data()?.to_vec();
4725
4726 for i in 0..hidden {
4728 let mut data_plus = input_data.clone();
4729 data_plus[i] += h;
4730 let inp_plus = leaf(&data_plus, &[1, hidden], false);
4731 let out_plus = no_grad(|| ln.forward(&inp_plus)).unwrap();
4732 let sum_plus: f64 = out_plus.data().unwrap().iter().sum();
4733
4734 let mut data_minus = input_data.clone();
4735 data_minus[i] -= h;
4736 let inp_minus = leaf(&data_minus, &[1, hidden], false);
4737 let out_minus = no_grad(|| ln.forward(&inp_minus)).unwrap();
4738 let sum_minus: f64 = out_minus.data().unwrap().iter().sum();
4739
4740 let numerical = (sum_plus - sum_minus) / (2.0 * h);
4741 assert!(
4742 (numerical - analytic[i]).abs() < 1e-4,
4743 "LayerNorm grad[{i}]: numerical={numerical}, analytic={}",
4744 analytic[i]
4745 );
4746 }
4747
4748 Ok(())
4749 }
4750
4751 #[test]
4752 fn test_layer_norm_named_parameters() {
4753 let ln = LayerNorm::<f32>::new(vec![16], 1e-5, true).unwrap();
4754 let named = ln.named_parameters();
4755 assert_eq!(named.len(), 2);
4756 assert_eq!(named[0].0, "weight");
4757 assert_eq!(named[1].0, "bias");
4758 }
4759
4760 #[test]
4761 fn test_layer_norm_train_eval() {
4762 let mut ln = LayerNorm::<f32>::new(vec![8], 1e-5, true).unwrap();
4763 assert!(ln.is_training());
4764 ln.eval();
4765 assert!(!ln.is_training());
4766 ln.train();
4767 assert!(ln.is_training());
4768 }
4769
4770 #[test]
4775 fn test_group_norm_parameter_shapes() {
4776 let gn = GroupNorm::<f32>::new(4, 8, 1e-5, true).unwrap();
4777 let params = gn.parameters();
4778 assert_eq!(params.len(), 2);
4779 assert_eq!(params[0].shape(), &[8]); assert_eq!(params[1].shape(), &[8]); }
4782
4783 #[test]
4784 fn test_group_norm_no_affine_no_params() {
4785 let gn = GroupNorm::<f32>::new(2, 4, 1e-5, false).unwrap();
4786 assert_eq!(gn.parameters().len(), 0);
4787 }
4788
4789 #[test]
4790 fn test_group_norm_invalid_groups() {
4791 assert!(GroupNorm::<f32>::new(0, 8, 1e-5, true).is_err());
4792 assert!(GroupNorm::<f32>::new(3, 8, 1e-5, true).is_err()); }
4794
4795 #[test]
4796 fn test_group_norm_forward_zero_mean_unit_var() {
4797 let data: Vec<f32> = vec![
4800 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
4802 ];
4803 let input = Tensor::from_storage(TensorStorage::cpu(data), vec![1, 4, 2], false).unwrap();
4805
4806 let gn = GroupNorm::<f32>::new(2, 4, 1e-5, true).unwrap();
4807 let output = gn.forward(&input).unwrap();
4808 let out_data = output.data().unwrap();
4809
4810 let group0: Vec<f32> = out_data[0..4].to_vec();
4812 let mean0: f32 = group0.iter().sum::<f32>() / 4.0;
4813 let var0: f32 = group0.iter().map(|&x| (x - mean0).powi(2)).sum::<f32>() / 4.0;
4814 assert!(mean0.abs() < 1e-5, "group0 mean = {mean0}");
4815 assert!((var0 - 1.0).abs() < 0.05, "group0 var = {var0}");
4816
4817 let group1: Vec<f32> = out_data[4..8].to_vec();
4819 let mean1: f32 = group1.iter().sum::<f32>() / 4.0;
4820 let var1: f32 = group1.iter().map(|&x| (x - mean1).powi(2)).sum::<f32>() / 4.0;
4821 assert!(mean1.abs() < 1e-5, "group1 mean = {mean1}");
4822 assert!((var1 - 1.0).abs() < 0.05, "group1 var = {var1}");
4823 }
4824
4825 #[test]
4826 fn test_group_norm_forward_shape_preserved() {
4827 let input =
4828 Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 48]), vec![2, 4, 6], false)
4829 .unwrap();
4830
4831 let gn = GroupNorm::<f32>::new(2, 4, 1e-5, true).unwrap();
4832 let output = gn.forward(&input).unwrap();
4833 assert_eq!(output.shape(), &[2, 4, 6]);
4834 }
4835
4836 #[test]
4837 fn test_group_norm_channel_mismatch() {
4838 let input =
4839 Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 24]), vec![2, 3, 4], false)
4840 .unwrap();
4841
4842 let gn = GroupNorm::<f32>::new(2, 4, 1e-5, true).unwrap();
4843 assert!(gn.forward(&input).is_err());
4844 }
4845
4846 #[test]
4847 fn test_group_norm_has_grad_fn() {
4848 let input =
4849 Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 8]), vec![1, 4, 2], true)
4850 .unwrap();
4851
4852 let gn = GroupNorm::<f32>::new(2, 4, 1e-5, true).unwrap();
4853 let output = gn.forward(&input).unwrap();
4854 assert!(output.grad_fn().is_some());
4855 assert_eq!(output.grad_fn().unwrap().name(), "GroupNormBackward");
4856 }
4857
4858 #[test]
4859 fn test_group_norm_backward_gradient_check() -> FerrotorchResult<()> {
4860 let h = 1e-7;
4861 let input_data = vec![1.0f64, -0.5, 2.0, 0.3, -1.0, 0.7, 1.5, -0.2];
4863 let gn = GroupNorm::<f64>::new(2, 4, 1e-5, true)?;
4864
4865 let input = leaf(&input_data, &[1, 4, 2], true);
4866 let output = gn.forward(&input)?;
4867 let out_data = output.data()?.to_vec();
4868 let total: f64 = out_data.iter().sum();
4869
4870 let sum_gf = Arc::new(SumBackwardHelper {
4871 input: output.clone(),
4872 });
4873 let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf)?;
4874 loss.backward()?;
4875
4876 let analytic_grad = input.grad().unwrap().unwrap();
4877 let analytic = analytic_grad.data()?.to_vec();
4878
4879 for i in 0..8 {
4880 let mut data_plus = input_data.clone();
4881 data_plus[i] += h;
4882 let inp_plus = leaf(&data_plus, &[1, 4, 2], false);
4883 let out_plus = no_grad(|| gn.forward(&inp_plus)).unwrap();
4884 let sum_plus: f64 = out_plus.data().unwrap().iter().sum();
4885
4886 let mut data_minus = input_data.clone();
4887 data_minus[i] -= h;
4888 let inp_minus = leaf(&data_minus, &[1, 4, 2], false);
4889 let out_minus = no_grad(|| gn.forward(&inp_minus)).unwrap();
4890 let sum_minus: f64 = out_minus.data().unwrap().iter().sum();
4891
4892 let numerical = (sum_plus - sum_minus) / (2.0 * h);
4893 assert!(
4894 (numerical - analytic[i]).abs() < 1e-4,
4895 "GroupNorm grad[{i}]: numerical={numerical}, analytic={}",
4896 analytic[i]
4897 );
4898 }
4899
4900 Ok(())
4901 }
4902
4903 #[test]
4904 fn test_group_norm_named_parameters() {
4905 let gn = GroupNorm::<f32>::new(2, 8, 1e-5, true).unwrap();
4906 let named = gn.named_parameters();
4907 assert_eq!(named.len(), 2);
4908 assert_eq!(named[0].0, "weight");
4909 assert_eq!(named[1].0, "bias");
4910 }
4911
4912 #[test]
4917 fn test_rms_norm_parameter_shapes() {
4918 let rn = RMSNorm::<f32>::new(vec![8], 1e-5).unwrap();
4919 let params = rn.parameters();
4920 assert_eq!(params.len(), 1);
4921 assert_eq!(params[0].shape(), &[8]); }
4923
4924 #[test]
4925 fn test_rms_norm_forward_scale() {
4926 let data: Vec<f32> = vec![
4928 1.0, 2.0, 3.0, 4.0, -1.0, 0.5, 2.0, -3.0, ];
4931 let input = Tensor::from_storage(TensorStorage::cpu(data), vec![2, 4], false).unwrap();
4932
4933 let rn = RMSNorm::<f32>::new(vec![4], 1e-5).unwrap();
4934 let output = rn.forward(&input).unwrap();
4935 let out_data = output.data().unwrap();
4936
4937 for row in 0..2 {
4938 let start = row * 4;
4939 let end = start + 4;
4940 let row_data = &out_data[start..end];
4941
4942 let mean_sq: f32 = row_data.iter().map(|x| x * x).sum::<f32>() / 4.0;
4943 let rms = mean_sq.sqrt();
4944
4945 assert!(
4946 (rms - 1.0).abs() < 0.05,
4947 "row {row} RMS = {rms}, expected ~1"
4948 );
4949 }
4950 }
4951
4952 #[test]
4953 fn test_rms_norm_forward_shape_preserved() {
4954 let input =
4955 Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 24]), vec![2, 3, 4], false)
4956 .unwrap();
4957
4958 let rn = RMSNorm::<f32>::new(vec![4], 1e-5).unwrap();
4959 let output = rn.forward(&input).unwrap();
4960 assert_eq!(output.shape(), &[2, 3, 4]);
4961 }
4962
4963 #[test]
4964 fn test_rms_norm_empty_normalized_shape() {
4965 assert!(RMSNorm::<f32>::new(vec![], 1e-5).is_err());
4966 }
4967
4968 #[test]
4969 fn test_rms_norm_has_grad_fn() {
4970 let input = Tensor::<f32>::from_storage(
4971 TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0]),
4972 vec![1, 4],
4973 true,
4974 )
4975 .unwrap();
4976
4977 let rn = RMSNorm::<f32>::new(vec![4], 1e-5).unwrap();
4978 let output = rn.forward(&input).unwrap();
4979 assert!(output.grad_fn().is_some());
4980 assert_eq!(output.grad_fn().unwrap().name(), "RMSNormBackward");
4981 }
4982
4983 #[test]
4984 fn test_rms_norm_no_grad_fn_in_no_grad_context() {
4985 let input = Tensor::<f32>::from_storage(
4986 TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0]),
4987 vec![1, 4],
4988 true,
4989 )
4990 .unwrap();
4991
4992 let rn = RMSNorm::<f32>::new(vec![4], 1e-5).unwrap();
4993 let output = no_grad(|| rn.forward(&input)).unwrap();
4994 assert!(output.grad_fn().is_none());
4995 }
4996
4997 #[test]
4998 fn test_rms_norm_backward_gradient_check() -> FerrotorchResult<()> {
4999 let h = 1e-7;
5000 let hidden = 4;
5001 let input_data = vec![1.0f64, -0.5, 2.0, 0.3];
5002
5003 let rn = RMSNorm::<f64>::new(vec![hidden], 1e-5)?;
5004
5005 let input = leaf(&input_data, &[1, hidden], true);
5006 let output = rn.forward(&input)?;
5007 let out_data = output.data()?.to_vec();
5008 let total: f64 = out_data.iter().sum();
5009
5010 let sum_gf = Arc::new(SumBackwardHelper {
5011 input: output.clone(),
5012 });
5013 let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf)?;
5014 loss.backward()?;
5015
5016 let analytic_grad = input.grad().unwrap().unwrap();
5017 let analytic = analytic_grad.data()?.to_vec();
5018
5019 for i in 0..hidden {
5020 let mut data_plus = input_data.clone();
5021 data_plus[i] += h;
5022 let inp_plus = leaf(&data_plus, &[1, hidden], false);
5023 let out_plus = no_grad(|| rn.forward(&inp_plus)).unwrap();
5024 let sum_plus: f64 = out_plus.data().unwrap().iter().sum();
5025
5026 let mut data_minus = input_data.clone();
5027 data_minus[i] -= h;
5028 let inp_minus = leaf(&data_minus, &[1, hidden], false);
5029 let out_minus = no_grad(|| rn.forward(&inp_minus)).unwrap();
5030 let sum_minus: f64 = out_minus.data().unwrap().iter().sum();
5031
5032 let numerical = (sum_plus - sum_minus) / (2.0 * h);
5033 assert!(
5034 (numerical - analytic[i]).abs() < 1e-4,
5035 "RMSNorm grad[{i}]: numerical={numerical}, analytic={}",
5036 analytic[i]
5037 );
5038 }
5039
5040 Ok(())
5041 }
5042
5043 #[test]
5044 fn test_rms_norm_named_parameters() {
5045 let rn = RMSNorm::<f32>::new(vec![16], 1e-5).unwrap();
5046 let named = rn.named_parameters();
5047 assert_eq!(named.len(), 1);
5048 assert_eq!(named[0].0, "weight");
5049 }
5050
5051 #[test]
5052 fn test_rms_norm_train_eval() {
5053 let mut rn = RMSNorm::<f32>::new(vec![8], 1e-5).unwrap();
5054 assert!(rn.is_training());
5055 rn.eval();
5056 assert!(!rn.is_training());
5057 rn.train();
5058 assert!(rn.is_training());
5059 }
5060
5061 #[test]
5066 fn test_batch_norm_2d_output_shape() {
5067 let bn = BatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
5068 let input = Tensor::from_storage(
5070 TensorStorage::cpu(vec![1.0f32; 2 * 3 * 4 * 4]),
5071 vec![2, 3, 4, 4],
5072 false,
5073 )
5074 .unwrap();
5075
5076 let output = bn.forward(&input).unwrap();
5077 assert_eq!(output.shape(), &[2, 3, 4, 4]);
5078 }
5079
5080 #[test]
5081 fn test_batch_norm_2d_rejects_non_4d() {
5082 let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
5083 let input =
5085 Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 24]), vec![2, 4, 3], false)
5086 .unwrap();
5087 assert!(bn.forward(&input).is_err());
5088 }
5089
5090 #[test]
5091 fn test_batch_norm_2d_channel_mismatch() {
5092 let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
5093 let input = Tensor::from_storage(
5094 TensorStorage::cpu(vec![1.0f32; 2 * 3 * 2 * 2]),
5095 vec![2, 3, 2, 2],
5096 false,
5097 )
5098 .unwrap();
5099 assert!(bn.forward(&input).is_err());
5100 }
5101
5102 #[test]
5103 fn test_batch_norm_2d_zero_features() {
5104 assert!(BatchNorm2d::<f32>::new(0, 1e-5, 0.1, true).is_err());
5105 }
5106
5107 #[test]
5108 fn test_batch_norm_2d_training_normalizes() {
5109 let channels = 2;
5112 let b = 2;
5113 let h = 3;
5114 let w = 3;
5115 let spatial = h * w;
5116 let mut data = Vec::new();
5118 for bi in 0..b {
5119 for c in 0..channels {
5120 let offset = c as f32 * 100.0;
5121 for s in 0..spatial {
5122 data.push(offset + (bi * spatial + s) as f32 + 1.0);
5123 }
5124 }
5125 }
5126 let input =
5127 Tensor::from_storage(TensorStorage::cpu(data), vec![b, channels, h, w], false).unwrap();
5128
5129 let bn = BatchNorm2d::<f32>::new(channels, 1e-5, 0.1, true).unwrap();
5130 let output = bn.forward(&input).unwrap();
5131 let out_data = output.data().unwrap();
5132
5133 for c in 0..channels {
5134 let mut vals = Vec::new();
5136 for bi in 0..b {
5137 let base = bi * channels * spatial + c * spatial;
5138 for s in 0..spatial {
5139 vals.push(out_data[base + s]);
5140 }
5141 }
5142 let n = vals.len() as f32;
5143 let mean: f32 = vals.iter().sum::<f32>() / n;
5144 let var: f32 = vals.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n;
5145
5146 assert!(mean.abs() < 1e-4, "channel {c}: mean = {mean}, expected ~0");
5147 assert!(
5148 (var - 1.0).abs() < 0.1,
5149 "channel {c}: var = {var}, expected ~1"
5150 );
5151 }
5152 }
5153
5154 #[test]
5155 fn test_batch_norm_2d_eval_uses_running_stats() {
5156 let channels = 2;
5157 let b = 4;
5158 let h = 2;
5159 let w = 2;
5160 let spatial = h * w;
5161
5162 let bn = BatchNorm2d::<f64>::new(channels, 1e-5, 0.1, true).unwrap();
5164
5165 let mut data = vec![0.0f64; b * channels * spatial];
5167 for bi in 0..b {
5168 for c in 0..channels {
5169 let base = bi * channels * spatial + c * spatial;
5170 for s in 0..spatial {
5171 data[base + s] = (c as f64) * 10.0 + (bi * spatial + s) as f64;
5172 }
5173 }
5174 }
5175 let input = Tensor::from_storage(
5176 TensorStorage::cpu(data.clone()),
5177 vec![b, channels, h, w],
5178 false,
5179 )
5180 .unwrap();
5181
5182 let _ = bn.forward(&input).unwrap();
5184 let rm_after_train = bn.running_mean();
5185 let rv_after_train = bn.running_var();
5186
5187 assert!(
5189 rm_after_train[0].abs() > 1e-6 || rm_after_train[1].abs() > 1e-6,
5190 "running_mean should have been updated"
5191 );
5192
5193 *bn.training.lock().unwrap() = false;
5196
5197 let output_eval = bn.forward(&input).unwrap();
5198 let eval_data = output_eval.data().unwrap();
5199
5200 for c in 0..channels {
5205 let expected_mean = rm_after_train[c];
5206 let expected_var = rv_after_train[c];
5207 let inv_std = 1.0 / (expected_var + 1e-5).sqrt();
5208
5209 for bi in 0..b {
5210 let base = bi * channels * spatial + c * spatial;
5211 for s in 0..spatial {
5212 let x = (c as f64) * 10.0 + (bi * spatial + s) as f64;
5213 let expected = (x - expected_mean) * inv_std;
5214 let actual = eval_data[base + s];
5216 assert!(
5217 (actual - expected).abs() < 1e-6,
5218 "eval output mismatch at b={bi}, c={c}, s={s}: actual={actual}, expected={expected}"
5219 );
5220 }
5221 }
5222 }
5223 }
5224
5225 #[test]
5226 fn test_batch_norm_2d_running_stats_update() {
5227 let channels = 2;
5228 let bn = BatchNorm2d::<f32>::new(channels, 1e-5, 0.1, true).unwrap();
5229
5230 assert_eq!(bn.running_mean(), vec![0.0, 0.0]);
5232 assert_eq!(bn.running_var(), vec![1.0, 1.0]);
5233 assert_eq!(bn.num_batches_tracked(), 0);
5234
5235 let input = Tensor::from_storage(
5237 TensorStorage::cpu(vec![1.0f32; 2 * 2 * 2 * 2]),
5238 vec![2, 2, 2, 2],
5239 false,
5240 )
5241 .unwrap();
5242 let _ = bn.forward(&input).unwrap();
5243 assert_eq!(bn.num_batches_tracked(), 1);
5244
5245 let rm = bn.running_mean();
5246 let rv = bn.running_var();
5247 assert!(
5249 (rm[0] - 0.1).abs() < 1e-5,
5250 "running_mean[0] = {}, expected 0.1",
5251 rm[0]
5252 );
5253 assert!(
5256 (rv[0] - 0.9).abs() < 1e-5,
5257 "running_var[0] = {}, expected 0.9",
5258 rv[0]
5259 );
5260
5261 let _ = bn.forward(&input).unwrap();
5263 assert_eq!(bn.num_batches_tracked(), 2);
5264 }
5265
5266 #[test]
5267 fn test_batch_norm_2d_affine_parameters() {
5268 let bn = BatchNorm2d::<f32>::new(8, 1e-5, 0.1, true).unwrap();
5269 let params = bn.parameters();
5270 assert_eq!(params.len(), 2);
5271 assert_eq!(params[0].shape(), &[8]); assert_eq!(params[1].shape(), &[8]); let named = bn.named_parameters();
5275 assert_eq!(named.len(), 2);
5276 assert_eq!(named[0].0, "weight");
5277 assert_eq!(named[1].0, "bias");
5278
5279 let weight_data = params[0].data().unwrap();
5281 let bias_data = params[1].data().unwrap();
5282 assert!(weight_data.iter().all(|&x| (x - 1.0).abs() < 1e-7));
5283 assert!(bias_data.iter().all(|&x| x.abs() < 1e-7));
5284 }
5285
5286 #[test]
5287 fn test_batch_norm_2d_no_affine_no_params() {
5288 let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, false).unwrap();
5289 assert_eq!(bn.parameters().len(), 0);
5290 assert_eq!(bn.named_parameters().len(), 0);
5291 assert!(bn.weight.is_none());
5292 assert!(bn.bias.is_none());
5293 }
5294
5295 #[test]
5296 fn test_batch_norm_2d_train_eval_toggle() {
5297 let mut bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
5298 assert!(bn.is_training());
5299 bn.eval();
5300 assert!(!bn.is_training());
5301 bn.train();
5302 assert!(bn.is_training());
5303 }
5304
5305 #[test]
5306 fn test_batch_norm_2d_has_grad_fn() {
5307 let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
5308 let input = Tensor::from_storage(
5309 TensorStorage::cpu(vec![1.0f32; 2 * 2 * 3 * 3]),
5310 vec![2, 2, 3, 3],
5311 true,
5312 )
5313 .unwrap();
5314
5315 let output = bn.forward(&input).unwrap();
5316 assert!(output.grad_fn().is_some());
5317 assert_eq!(output.grad_fn().unwrap().name(), "BatchNorm2dBackward");
5318 }
5319
5320 #[test]
5321 fn test_batch_norm_2d_no_grad_fn_in_no_grad_context() {
5322 let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
5323 let input = Tensor::from_storage(
5324 TensorStorage::cpu(vec![1.0f32; 2 * 2 * 3 * 3]),
5325 vec![2, 2, 3, 3],
5326 true,
5327 )
5328 .unwrap();
5329
5330 let output = no_grad(|| bn.forward(&input)).unwrap();
5331 assert!(output.grad_fn().is_none());
5332 }
5333
5334 #[test]
5335 fn test_batch_norm_2d_backward_gradient_check() -> FerrotorchResult<()> {
5336 let h_eps = 1e-7;
5337 let channels = 2;
5338 let b = 2;
5339 let height = 2;
5340 let width = 2;
5341 let spatial = height * width;
5342 let numel = b * channels * spatial;
5343
5344 let input_data: Vec<f64> = (0..numel).map(|i| (i as f64) * 0.3 - 1.0).collect();
5346
5347 let bn = BatchNorm2d::<f64>::new(channels, 1e-5, 0.1, true)?;
5348
5349 let input = leaf(&input_data, &[b, channels, height, width], true);
5350 let output = bn.forward(&input)?;
5351 let out_data = output.data()?.to_vec();
5352 let total: f64 = out_data.iter().sum();
5353
5354 let sum_gf = Arc::new(SumBackwardHelper {
5355 input: output.clone(),
5356 });
5357 let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf)?;
5358 loss.backward()?;
5359
5360 let analytic_grad = input.grad().unwrap().unwrap();
5361 let analytic = analytic_grad.data()?.to_vec();
5362
5363 for i in 0..numel {
5367 let mut data_plus = input_data.clone();
5369 data_plus[i] += h_eps;
5370 let inp_plus = leaf(&data_plus, &[b, channels, height, width], false);
5371 let bn_plus = BatchNorm2d::<f64>::new(channels, 1e-5, 0.1, true)?;
5372 let out_plus = no_grad(|| bn_plus.forward(&inp_plus)).unwrap();
5373 let sum_plus: f64 = out_plus.data().unwrap().iter().sum();
5374
5375 let mut data_minus = input_data.clone();
5377 data_minus[i] -= h_eps;
5378 let inp_minus = leaf(&data_minus, &[b, channels, height, width], false);
5379 let bn_minus = BatchNorm2d::<f64>::new(channels, 1e-5, 0.1, true)?;
5380 let out_minus = no_grad(|| bn_minus.forward(&inp_minus)).unwrap();
5381 let sum_minus: f64 = out_minus.data().unwrap().iter().sum();
5382
5383 let numerical = (sum_plus - sum_minus) / (2.0 * h_eps);
5384 assert!(
5385 (numerical - analytic[i]).abs() < 1e-4,
5386 "BatchNorm2d grad[{i}]: numerical={numerical}, analytic={}",
5387 analytic[i]
5388 );
5389 }
5390
5391 Ok(())
5392 }
5393
5394 #[test]
5395 fn test_batch_norm_2d_no_affine_forward() {
5396 let channels = 2;
5398 let b = 2;
5399 let h = 2;
5400 let w = 2;
5401 let spatial = h * w;
5402
5403 let mut data = Vec::new();
5404 for bi in 0..b {
5405 for c in 0..channels {
5406 for s in 0..spatial {
5407 data.push((c as f32) * 5.0 + (bi * spatial + s) as f32);
5408 }
5409 }
5410 }
5411
5412 let input =
5413 Tensor::from_storage(TensorStorage::cpu(data), vec![b, channels, h, w], false).unwrap();
5414
5415 let bn = BatchNorm2d::<f32>::new(channels, 1e-5, 0.1, false).unwrap();
5416 let output = bn.forward(&input).unwrap();
5417 let out_data = output.data().unwrap();
5418
5419 for c in 0..channels {
5420 let mut vals = Vec::new();
5421 for bi in 0..b {
5422 let base = bi * channels * spatial + c * spatial;
5423 for s in 0..spatial {
5424 vals.push(out_data[base + s]);
5425 }
5426 }
5427 let n = vals.len() as f32;
5428 let mean: f32 = vals.iter().sum::<f32>() / n;
5429 let var: f32 = vals.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n;
5430
5431 assert!(mean.abs() < 1e-4, "no-affine channel {c}: mean = {mean}");
5432 assert!(
5433 (var - 1.0).abs() < 0.1,
5434 "no-affine channel {c}: var = {var}"
5435 );
5436 }
5437 }
5438
5439 #[test]
5444 fn test_layer_norm_is_send_sync() {
5445 fn assert_send_sync<T: Send + Sync>() {}
5446 assert_send_sync::<LayerNorm<f32>>();
5447 }
5448
5449 #[test]
5450 fn test_group_norm_is_send_sync() {
5451 fn assert_send_sync<T: Send + Sync>() {}
5452 assert_send_sync::<GroupNorm<f32>>();
5453 }
5454
5455 #[test]
5456 fn test_rms_norm_is_send_sync() {
5457 fn assert_send_sync<T: Send + Sync>() {}
5458 assert_send_sync::<RMSNorm<f32>>();
5459 }
5460
5461 #[test]
5462 fn test_batch_norm_2d_is_send_sync() {
5463 fn assert_send_sync<T: Send + Sync>() {}
5464 assert_send_sync::<BatchNorm2d<f32>>();
5465 }
5466
5467 fn one<T: Float>() -> T {
5473 <T as num_traits::One>::one()
5474 }
5475
5476 #[derive(Debug)]
5478 struct SumBackwardHelper<T: Float> {
5479 input: Tensor<T>,
5480 }
5481
5482 impl<T: Float> GradFn<T> for SumBackwardHelper<T> {
5483 fn backward(&self, _grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
5484 let ones_data = vec![one::<T>(); self.input.numel()];
5485 let ones = Tensor::from_storage(
5486 TensorStorage::cpu(ones_data),
5487 self.input.shape().to_vec(),
5488 false,
5489 )?;
5490 Ok(vec![Some(ones)])
5491 }
5492
5493 fn inputs(&self) -> Vec<&Tensor<T>> {
5494 vec![&self.input]
5495 }
5496
5497 fn name(&self) -> &'static str {
5498 "SumBackwardHelper"
5499 }
5500 }
5501
5502 #[test]
5507 fn test_batchnorm1d_parameter_shapes() {
5508 let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
5509 let params = bn.parameters();
5510 assert_eq!(params.len(), 2);
5511 assert_eq!(params[0].shape(), &[3]); assert_eq!(params[1].shape(), &[3]); }
5514
5515 #[test]
5516 fn test_batchnorm1d_no_affine() {
5517 let bn = BatchNorm1d::<f32>::new(4, 1e-5, 0.1, false).unwrap();
5518 assert!(bn.parameters().is_empty());
5519 }
5520
#[test]
fn test_batchnorm1d_2d_input() {
    // An (N, C) = (4, 2) input keeps its shape through the layer.
    let values: Vec<f32> = (1..=8).map(|v| v as f32).collect();
    let input = Tensor::from_storage(TensorStorage::cpu(values), vec![4, 2], false).unwrap();
    let bn = BatchNorm1d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    let output = bn.forward(&input).unwrap();
    assert_eq!(output.shape(), &[4, 2]);
}
5534
#[test]
fn test_batchnorm1d_3d_input() {
    // An (N, C, L) = (2, 3, 4) ramp input keeps its shape through the layer.
    let dims: Vec<usize> = vec![2, 3, 4];
    let numel: usize = dims.iter().product();
    let ramp: Vec<f32> = (0..numel).map(|i| i as f32).collect();
    let input = Tensor::from_storage(TensorStorage::cpu(ramp), dims, false).unwrap();
    let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    assert_eq!(bn.forward(&input).unwrap().shape(), &[2, 3, 4]);
}
5544
#[test]
fn test_batchnorm1d_wrong_dims() {
    // Rank-1 and rank-4 inputs must both be rejected.
    let bn = BatchNorm1d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
    let invalid_cases: [(Vec<f32>, Vec<usize>); 2] = [
        (vec![1.0, 2.0, 3.0, 4.0], vec![4]),
        (vec![0.0; 32], vec![2, 4, 2, 2]),
    ];
    for (values, dims) in invalid_cases {
        let input = Tensor::from_storage(TensorStorage::cpu(values), dims, false).unwrap();
        assert!(bn.forward(&input).is_err());
    }
}
5566
#[test]
fn test_batchnorm1d_zero_features() {
    // A layer over zero channels must be refused at construction time.
    let created = BatchNorm1d::<f32>::new(0, 1e-5, 0.1, true);
    assert!(created.is_err());
}
5571
#[test]
fn test_batchnorm1d_channel_mismatch() {
    // The layer is configured for 3 channels but the (N, C) input carries 4.
    let input =
        Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 2 * 4]), vec![2, 4], false).unwrap();
    let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    let result = bn.forward(&input);
    assert!(result.is_err());
}
5579
#[test]
fn test_batchnorm1d_training_normalizes() {
    let (batch, channels) = (4usize, 2usize);
    let bn = BatchNorm1d::<f64>::new(channels, 1e-5, 0.1, true).unwrap();
    let input = leaf(
        &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        &[batch, channels],
        false,
    );
    let output = bn.forward(&input).unwrap();
    let data = output.data().unwrap();

    // Training mode normalises with batch statistics, so *every* channel
    // (column of the (N, C) output) must come out with mean ~0 and
    // population variance ~1, not just channel 0.
    let count = batch as f64;
    for c in 0..channels {
        let column: Vec<f64> = (0..batch).map(|b| data[b * channels + c]).collect();
        let mean = column.iter().sum::<f64>() / count;
        assert!(
            mean.abs() < 1e-5,
            "BatchNorm1d channel {c} mean should be ~0, got {mean}"
        );
        // Biased (population) variance of the normalised column.
        let var = column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / count;
        assert!(
            (var - 1.0).abs() < 1e-3,
            "BatchNorm1d channel {c} var should be ~1, got {var}"
        );
    }
}
5609
#[test]
fn test_batchnorm1d_running_stats_update() {
    let bn = BatchNorm1d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
    let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2], false);
    let _ = bn.forward(&input).unwrap();

    // Exactly one training step has been recorded.
    assert_eq!(bn.num_batches_tracked(), 1);

    // The batch's channel means are 4.0 and 5.0. With momentum 0.1 the test
    // expects running_mean = 0.1 * batch_mean (i.e. a zero starting point).
    let want_0 = 0.1 * 4.0;
    let want_1 = 0.1 * 5.0;
    let rm = bn.running_mean();
    assert!(
        (rm[0] - want_0).abs() < 1e-7,
        "running_mean[0]: expected {}, got {}",
        want_0,
        rm[0]
    );
    assert!(
        (rm[1] - want_1).abs() < 1e-7,
        "running_mean[1]: expected {}, got {}",
        want_1,
        rm[1]
    );

    // The running variance must stay strictly positive.
    let rv = bn.running_var();
    assert!(rv[0] > 0.0);
    assert!(rv[1] > 0.0);
}
5639
#[test]
fn test_batchnorm1d_eval_mode() {
    let bn = BatchNorm1d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
    let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2], false);

    // One training step populates the running statistics.
    let _ = bn.forward(&input).unwrap();
    // Take owned snapshots so no accessor result is held across the next forward.
    let mean_before = bn.running_mean().to_vec();
    let var_before = bn.running_var().to_vec();

    *bn.training.lock().unwrap() = false;
    let eval_out = bn.forward(&input).unwrap();
    assert_eq!(eval_out.shape(), &[4, 2]);

    // An eval-mode forward must read, not update, the running statistics.
    // Otherwise a shape-only check would pass even if eval kept training.
    assert_eq!(
        bn.num_batches_tracked(),
        1,
        "eval forward must not advance num_batches_tracked"
    );
    assert_eq!(
        bn.running_mean().to_vec(),
        mean_before,
        "eval forward must not change running_mean"
    );
    assert_eq!(
        bn.running_var().to_vec(),
        var_before,
        "eval forward must not change running_var"
    );
}
5655
#[test]
fn test_batchnorm1d_no_affine_normalizes() {
    let (batch, channels) = (4usize, 2usize);
    let bn = BatchNorm1d::<f64>::new(channels, 1e-5, 0.1, false).unwrap();
    let input = leaf(
        &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        &[batch, channels],
        false,
    );
    let output = bn.forward(&input).unwrap();
    assert_eq!(output.shape(), &[4, 2]);

    // With no affine transform the output is the bare standardisation, so
    // each channel must have mean ~0 and population variance ~1. The test
    // name promises this, but previously only the shape was checked.
    let data = output.data().unwrap();
    let count = batch as f64;
    for c in 0..channels {
        let column: Vec<f64> = (0..batch).map(|b| data[b * channels + c]).collect();
        let mean = column.iter().sum::<f64>() / count;
        let var = column.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / count;
        assert!(
            mean.abs() < 1e-6,
            "non-affine BatchNorm1d channel {c} mean should be ~0, got {mean}"
        );
        assert!(
            (var - 1.0).abs() < 1e-3,
            "non-affine BatchNorm1d channel {c} var should be ~1, got {var}"
        );
    }
}
5663
#[test]
fn test_batchnorm1d_3d_normalizes() {
    let (batch, channels, length) = (2usize, 2usize, 3usize);
    let bn = BatchNorm1d::<f64>::new(channels, 1e-5, 0.1, true).unwrap();
    let data: Vec<f64> = (0..batch * channels * length).map(|i| i as f64).collect();
    let input = leaf(&data, &[batch, channels, length], false);
    let output = bn.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 2, 3]);

    let out_data = output.data().unwrap();
    let count = (batch * length) as f64;
    // Every channel, not just channel 0, must be normalised over the combined
    // batch and length axes. In (N, C, L) layout element (b, c, l) lives at
    // flat index (b * C + c) * L + l.
    for c in 0..channels {
        let values: Vec<f64> = (0..batch)
            .flat_map(|b| (0..length).map(move |l| (b * channels + c) * length + l))
            .map(|idx| out_data[idx])
            .collect();
        let mean = values.iter().sum::<f64>() / count;
        let var = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / count;
        assert!(
            mean.abs() < 1e-5,
            "BatchNorm1d 3D channel {c} mean should be ~0, got {mean}"
        );
        assert!(
            (var - 1.0).abs() < 1e-3,
            "BatchNorm1d 3D channel {c} var should be ~1, got {var}"
        );
    }
}
5694
#[test]
fn test_batchnorm1d_train_eval_toggle() {
    let bn = BatchNorm1d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
    // A freshly constructed layer starts in training mode.
    assert!(bn.is_training());
    // Switching the flag off and back on is reflected by `is_training`.
    for mode in [false, true] {
        *bn.training.lock().unwrap() = mode;
        assert_eq!(bn.is_training(), mode);
    }
}
5704
#[test]
fn test_batchnorm1d_grad_fn_name() {
    // A grad-requiring input must yield an output wired to BatchNorm1d's backward node.
    let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2], true);
    let bn = BatchNorm1d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
    let output = bn.forward(&input).unwrap();
    let node = output.grad_fn().unwrap();
    assert_eq!(node.name(), "BatchNorm1dBackward");
}
5712
#[test]
fn test_batchnorm1d_backward_grad_shapes() {
    use ferrotorch_core::autograd::graph::backward;

    let bn = BatchNorm1d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
    let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2], true);
    let output = bn.forward(&input).unwrap();

    // Scalar loss = sum(output), linked back to `output` by SumBackwardHelper.
    let total: f64 = output.data().unwrap().iter().sum();
    let helper = SumBackwardHelper {
        input: output.clone(),
    };
    let loss =
        Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], Arc::new(helper)).unwrap();
    backward(&loss).unwrap();

    // The input gradient mirrors the (N, C) input shape.
    let grad = input.grad().unwrap().unwrap();
    assert_eq!(grad.shape(), &[4, 2]);
}
5733
#[test]
fn test_batchnorm2d_nonaffine_backward_runs_grad_input_only() {
    use ferrotorch_core::autograd::graph::backward;

    let bn = BatchNorm2d::<f64>::new(4, 1e-5, 0.1, false).unwrap();
    assert!(bn.weight.is_none() && bn.bias.is_none());

    // (N, C, H, W) = (3, 4, 2, 3) input built from a small repeating ramp.
    let shape = [3usize, 4, 2, 3];
    let data: Vec<f64> = (0..shape.iter().product::<usize>())
        .map(|k| ((k % 19) as f64) * 0.11 - 1.0)
        .collect();
    let input = leaf(&data, &shape, true);
    let output = bn.forward(&input).unwrap();

    // Without affine parameters the backward node tracks only the input.
    assert_eq!(output.grad_fn().unwrap().inputs().len(), 1);

    let total: f64 = output.data().unwrap().iter().sum();
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(SumBackwardHelper {
            input: output.clone(),
        }),
    )
    .unwrap();
    backward(&loss).unwrap();

    let grad = input.grad().unwrap().expect("grad_input populated");
    assert_eq!(grad.shape(), shape.as_slice());
}
5767
#[test]
fn test_batchnorm1d_nonaffine_backward_runs_grad_input_only() {
    use ferrotorch_core::autograd::graph::backward;

    let bn = BatchNorm1d::<f64>::new(3, 1e-5, 0.1, false).unwrap();

    // (N, C, L) = (4, 3, 5) input built from a small repeating ramp.
    let shape = [4usize, 3, 5];
    let data: Vec<f64> = (0..shape.iter().product::<usize>())
        .map(|k| ((k % 17) as f64) * 0.13 - 0.9)
        .collect();
    let input = leaf(&data, &shape, true);
    let output = bn.forward(&input).unwrap();
    // Non-affine: the backward node depends on the input tensor alone.
    assert_eq!(output.grad_fn().unwrap().inputs().len(), 1);

    let total: f64 = output.data().unwrap().iter().sum();
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(SumBackwardHelper {
            input: output.clone(),
        }),
    )
    .unwrap();
    backward(&loss).unwrap();

    let grad = input.grad().unwrap().expect("grad_input populated");
    assert_eq!(grad.shape(), shape.as_slice());
}
5789
#[test]
fn test_batchnorm3d_nonaffine_backward_runs_grad_input_only() {
    use ferrotorch_core::autograd::graph::backward;

    let bn = BatchNorm3d::<f64>::new(3, 1e-5, 0.1, false).unwrap();

    // (N, C, D, H, W) = (2, 3, 2, 2, 2) input built from a small repeating ramp.
    let shape = [2usize, 3, 2, 2, 2];
    let data: Vec<f64> = (0..shape.iter().product::<usize>())
        .map(|k| ((k % 13) as f64) * 0.17 - 1.0)
        .collect();
    let input = leaf(&data, &shape, true);
    let output = bn.forward(&input).unwrap();
    // Non-affine: the backward node depends on the input tensor alone.
    assert_eq!(output.grad_fn().unwrap().inputs().len(), 1);

    let total: f64 = output.data().unwrap().iter().sum();
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(SumBackwardHelper {
            input: output.clone(),
        }),
    )
    .unwrap();
    backward(&loss).unwrap();

    let grad = input.grad().unwrap().expect("grad_input populated");
    assert_eq!(grad.shape(), shape.as_slice());
}
5811
#[test]
fn test_batchnorm2d_affine_backward_still_returns_three_grads() {
    use ferrotorch_core::autograd::graph::backward;

    let bn = BatchNorm2d::<f64>::new(4, 1e-5, 0.1, true).unwrap();
    let shape = [3usize, 4, 2, 3];
    let data: Vec<f64> = (0..shape.iter().product::<usize>())
        .map(|k| ((k % 19) as f64) * 0.11 - 1.0)
        .collect();
    let input = leaf(&data, &shape, true);
    let output = bn.forward(&input).unwrap();
    // Affine: the backward node has three inputs.
    assert_eq!(output.grad_fn().unwrap().inputs().len(), 3);

    let total: f64 = output.data().unwrap().iter().sum();
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(SumBackwardHelper {
            input: output.clone(),
        }),
    )
    .unwrap();
    backward(&loss).unwrap();

    // Input, weight and bias all received gradients.
    assert!(input.grad().unwrap().is_some());
    let weight = bn.weight.as_ref().unwrap();
    assert!(weight.tensor().grad().unwrap().is_some());
    let bias = bn.bias.as_ref().unwrap();
    assert!(bias.tensor().grad().unwrap().is_some());
}
5845
#[test]
fn test_batchnorm1d_backward_numerical() {
    use ferrotorch_core::autograd::graph::backward;

    let channels = 2;
    let eps_val = 1e-5;
    let input_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    let shape = [4usize, 2];

    // Analytic gradient of sum(BatchNorm1d(x)) through autograd.
    let bn = BatchNorm1d::<f64>::new(channels, eps_val, 0.1, true).unwrap();
    let input = leaf(&input_data, &shape, true);
    let output = bn.forward(&input).unwrap();
    let total: f64 = output.data().unwrap().iter().sum();
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(SumBackwardHelper {
            input: output.clone(),
        }),
    )
    .unwrap();
    backward(&loss).unwrap();
    let analytic_grad = input.grad().unwrap().unwrap().data_vec().unwrap();

    // Loss evaluated on a freshly built module each time, so running-stat
    // updates from earlier probes cannot leak into later ones.
    let loss_at = |values: &[f64]| -> f64 {
        let probe = BatchNorm1d::<f64>::new(channels, eps_val, 0.1, true).unwrap();
        let x = leaf(values, &shape, false);
        let y = no_grad(|| probe.forward(&x)).unwrap();
        let sum: f64 = y.data().unwrap().iter().sum();
        sum
    };

    // Central finite differences, compared element by element.
    let h = 1e-5;
    for i in 0..input_data.len() {
        let mut plus = input_data.clone();
        plus[i] += h;
        let mut minus = input_data.clone();
        minus[i] -= h;
        let numerical = (loss_at(&plus) - loss_at(&minus)) / (2.0 * h);
        assert!(
            (analytic_grad[i] - numerical).abs() < 1e-3,
            "BatchNorm1d grad[{}]: numerical={}, analytic={}",
            i,
            numerical,
            analytic_grad[i]
        );
    }
}
5900
#[test]
fn test_batchnorm1d_empty_batch() {
    // A zero-sample batch passes through with its (0, C) shape intact.
    let empty =
        Tensor::from_storage(TensorStorage::cpu(Vec::<f32>::new()), vec![0, 3], false).unwrap();
    let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    let output = bn.forward(&empty).unwrap();
    assert_eq!(output.numel(), 0);
    assert_eq!(output.shape(), &[0, 3]);
}
5909
#[test]
fn test_batchnorm1d_is_send_sync() {
    // Compile-time check: this only type-checks when BatchNorm1d<f32> is Send + Sync.
    fn require_thread_safe<U: Send + Sync>() {}
    require_thread_safe::<BatchNorm1d<f32>>();
}
5915
#[test]
fn test_batchnorm3d_output_shape() {
    // An (N, C, D, H, W) = (2, 3, 2, 2, 2) constant input keeps its shape.
    let dims: Vec<usize> = vec![2, 3, 2, 2, 2];
    let numel: usize = dims.iter().product();
    let input =
        Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; numel]), dims, false).unwrap();
    let bn = BatchNorm3d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    let output = bn.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 3, 2, 2, 2]);
}
5932
#[test]
fn test_batchnorm3d_rejects_non_5d() {
    // A rank-3 tensor is not a valid BatchNorm3d input.
    let bn = BatchNorm3d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    let rank3 =
        Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 2 * 3 * 4]), vec![2, 3, 4], false)
            .unwrap();
    assert!(bn.forward(&rank3).is_err());
}
5941
#[test]
fn test_batchnorm3d_channel_mismatch() {
    // The layer is configured for 3 channels but the input carries 4.
    let dims: Vec<usize> = vec![2, 4, 2, 2, 2];
    let numel: usize = dims.iter().product();
    let input =
        Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; numel]), dims, false).unwrap();
    let bn = BatchNorm3d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    assert!(bn.forward(&input).is_err());
}
5953
#[test]
fn test_batchnorm3d_zero_features_rejected() {
    // Zero channels must be refused at construction time.
    let created = BatchNorm3d::<f32>::new(0, 1e-5, 0.1, true);
    assert!(created.is_err());
}
5958
#[test]
fn test_batchnorm3d_training_normalizes() {
    let channels = 2;
    let bn = BatchNorm3d::<f64>::new(channels, 1e-5, 0.1, true).unwrap();
    let data: Vec<f64> = (0..2 * 2 * 2 * 2 * 2).map(|i| i as f64).collect();
    let input = leaf(&data, &[2, 2, 2, 2, 2], false);
    let output = bn.forward(&input).unwrap();
    let out_data = output.data().unwrap();

    // Within each sample a channel occupies a contiguous run of D*H*W = 8
    // values. Averaged over the whole batch, every channel must be ~0.
    let (batch, spatial) = (2, 2 * 2 * 2);
    for c in 0..channels {
        let sum: f64 = (0..batch)
            .flat_map(|b| {
                let base = b * channels * spatial + c * spatial;
                base..base + spatial
            })
            .map(|idx| out_data[idx])
            .sum();
        let mean = sum / (batch * spatial) as f64;
        assert!(mean.abs() < 1e-5, "channel {c} mean = {mean}, expected ~0");
    }
}
5987
#[test]
fn test_batchnorm3d_running_stats_updated() {
    let bn = BatchNorm3d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
    let data: Vec<f64> = (0..32).map(|i| i as f64).collect();
    let input = leaf(&data, &[2, 2, 2, 2, 2], false);
    let _ = bn.forward(&input).unwrap();

    assert_eq!(bn.num_batches_tracked(), 1);

    // With (N, C, D*H*W) = (2, 2, 8), channel 0 holds the values 0..8 and
    // 16..24 (mean 11.5), and channel 1 holds 8..16 and 24..32 (mean 19.5).
    // This uses the same momentum rule the BatchNorm1d running-stats test
    // pins: from a zero starting point, momentum 0.1 gives 0.1 * batch_mean.
    // That is much stronger than merely checking that something changed.
    let rm = bn.running_mean();
    for (c, batch_mean) in [11.5_f64, 19.5].iter().enumerate() {
        let expected = 0.1 * batch_mean;
        assert!(
            (rm[c] - expected).abs() < 1e-7,
            "running_mean[{c}]: expected {expected}, got {}",
            rm[c]
        );
    }
}
6002
#[test]
fn test_batchnorm3d_eval_uses_running_stats() {
    let mut bn = BatchNorm3d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
    let data: Vec<f64> = (0..32).map(|i| i as f64).collect();
    let input = leaf(&data, &[2, 2, 2, 2, 2], false);
    // One training step moves the running statistics away from their initial values.
    let _ = bn.forward(&input).unwrap();

    bn.eval();
    // Owned snapshots, so nothing returned by the accessors is held across forward.
    let running_mean = bn.running_mean().to_vec();
    let running_var = bn.running_var().to_vec();
    let output = bn.forward(&input).unwrap();
    assert_eq!(output.shape(), &[2, 2, 2, 2, 2]);

    // In eval mode every element must be normalised with the *running*
    // statistics: y = (x - running_mean[c]) / sqrt(running_var[c] + eps).
    // This assumes the affine weight/bias are still at their identity
    // initialisation (weight 1, bias 0). TODO: confirm if BatchNorm3d init changes.
    let (channels, spatial, eps) = (2, 8, 1e-5);
    let out = output.data().unwrap();
    for (idx, (&x, &y)) in data.iter().zip(out.iter()).enumerate() {
        let c = (idx / spatial) % channels;
        let expected = (x - running_mean[c]) / (running_var[c] + eps).sqrt();
        assert!(
            (y - expected).abs() < 1e-8,
            "element {idx} (channel {c}): got {y}, expected {expected}"
        );
    }
}
6016
#[test]
fn test_batchnorm3d_parameters() {
    let bn = BatchNorm3d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
    let params = bn.parameters();
    assert_eq!(params.len(), 2);
    // Each affine parameter holds one value per channel.
    params.iter().for_each(|p| assert_eq!(p.shape(), &[4]));
}
6025
#[test]
fn test_batchnorm3d_no_affine_no_params() {
    // With affine disabled there is nothing learnable.
    let bn = BatchNorm3d::<f32>::new(4, 1e-5, 0.1, false).unwrap();
    assert_eq!(bn.parameters().len(), 0);
}
6031
#[test]
fn test_batchnorm3d_backward_grad_shapes() {
    use ferrotorch_core::autograd::graph::backward;

    let bn = BatchNorm3d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
    let data: Vec<f64> = (0..32).map(|i| i as f64).collect();
    let input = leaf(&data, &[2, 2, 2, 2, 2], true);
    let output = bn.forward(&input).unwrap();

    // Scalar loss = sum(output), linked back to `output` by SumBackwardHelper.
    let total: f64 = output.data().unwrap().iter().sum();
    let helper = SumBackwardHelper {
        input: output.clone(),
    };
    let loss =
        Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], Arc::new(helper)).unwrap();
    backward(&loss).unwrap();

    // The input gradient mirrors the 5-D input shape.
    let grad = input.grad().unwrap().unwrap();
    assert_eq!(grad.shape(), &[2, 2, 2, 2, 2]);
}
6052
#[test]
fn test_batchnorm3d_backward_numerical() {
    use ferrotorch_core::autograd::graph::backward;

    let channels = 2;
    let eps_val = 1e-5;
    let shape = [2usize, 2, 2, 2, 2];
    let data: Vec<f64> = (0..32).map(|i| i as f64 * 0.1).collect();

    // Analytic gradient of sum(BatchNorm3d(x)) through autograd.
    let bn = BatchNorm3d::<f64>::new(channels, eps_val, 0.1, true).unwrap();
    let input = leaf(&data, &shape, true);
    let output = bn.forward(&input).unwrap();
    let total: f64 = output.data().unwrap().iter().sum();
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(SumBackwardHelper {
            input: output.clone(),
        }),
    )
    .unwrap();
    backward(&loss).unwrap();
    let analytic_grad = input.grad().unwrap().unwrap().data_vec().unwrap();

    // Loss evaluated on a freshly built module each time, so running-stat
    // updates from earlier probes cannot leak into later ones.
    let loss_at = |values: &[f64]| -> f64 {
        let probe = BatchNorm3d::<f64>::new(channels, eps_val, 0.1, true).unwrap();
        let x = leaf(values, &shape, false);
        let y = no_grad(|| probe.forward(&x)).unwrap();
        let sum: f64 = y.data().unwrap().iter().sum();
        sum
    };

    // Central finite differences, checked per element.
    let h = 1e-5;
    for i in 0..data.len() {
        let mut plus = data.clone();
        plus[i] += h;
        let mut minus = data.clone();
        minus[i] -= h;
        let numerical = (loss_at(&plus) - loss_at(&minus)) / (2.0 * h);
        assert!(
            (analytic_grad[i] - numerical).abs() < 1e-3,
            "BatchNorm3d grad[{}]: numerical={}, analytic={}",
            i,
            numerical,
            analytic_grad[i]
        );
    }
}
6100
#[test]
fn test_batchnorm3d_is_send_sync() {
    // Compile-time check: this only type-checks when BatchNorm3d<f32> is Send + Sync.
    fn require_thread_safe<U: Send + Sync>() {}
    require_thread_safe::<BatchNorm3d<f32>>();
}
6106
#[test]
fn test_lrn_output_shape() {
    // LRN keeps the shape of an (N, C, H, W) = (2, 4, 3, 3) input.
    let dims: Vec<usize> = vec![2, 4, 3, 3];
    let numel = dims.iter().product::<usize>();
    let input =
        Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0f32; numel]), dims, false)
            .unwrap();
    let lrn = LocalResponseNorm::new(5, 1e-4, 0.75, 1.0).unwrap();
    let output = Module::<f32>::forward(&lrn, &input).unwrap();
    assert_eq!(output.shape(), &[2, 4, 3, 3]);
}
6123
#[test]
fn test_lrn_3d_input() {
    // A rank-3 (N, C, L) = (2, 4, 8) input is accepted and keeps its shape.
    let dims: Vec<usize> = vec![2, 4, 8];
    let numel = dims.iter().product::<usize>();
    let input =
        Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0f32; numel]), dims, false)
            .unwrap();
    let lrn = LocalResponseNorm::new(3, 1e-4, 0.75, 1.0).unwrap();
    let output = Module::<f32>::forward(&lrn, &input).unwrap();
    assert_eq!(output.shape(), &[2, 4, 8]);
}
6136
#[test]
fn test_lrn_rejects_2d() {
    // A rank-2 input must be rejected.
    let rank2 =
        Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0f32; 2 * 4]), vec![2, 4], false)
            .unwrap();
    let lrn = LocalResponseNorm::new(3, 1e-4, 0.75, 1.0).unwrap();
    let result = Module::<f32>::forward(&lrn, &rank2);
    assert!(result.is_err());
}
6145
#[test]
fn test_lrn_zero_size_rejected() {
    // An empty neighbourhood size must be refused at construction time.
    let outcome = LocalResponseNorm::new(0, 1e-4, 0.75, 1.0);
    assert!(outcome.is_err());
}
6150
#[test]
fn test_lrn_default_params() {
    // `default_params` keeps the given size and fills alpha=1e-4, beta=0.75, k=1.
    let lrn = LocalResponseNorm::default_params(5).unwrap();
    let LocalResponseNorm {
        size,
        alpha,
        beta,
        k,
        ..
    } = &lrn;
    assert_eq!(*size, 5);
    assert!((*alpha - 1e-4).abs() < 1e-10);
    assert!((*beta - 0.75).abs() < 1e-10);
    assert!((*k - 1.0).abs() < 1e-10);
}
6159
#[test]
fn test_lrn_no_parameters() {
    // LRN exposes no learnable parameters.
    let lrn = LocalResponseNorm::new(5, 1e-4, 0.75, 1.0).unwrap();
    let params = Module::<f32>::parameters(&lrn);
    assert_eq!(params.len(), 0);
}
6165
#[test]
fn test_lrn_divides_by_norm() {
    // With alpha = 10, beta = 1 and k = 1, unit inputs are expected to come
    // out attenuated into the open interval (0, 1).
    let lrn = LocalResponseNorm::new(3, 10.0, 1.0, 1.0).unwrap();
    let input = Tensor::<f32>::from_storage(
        TensorStorage::cpu(vec![1.0; 3 * 2]),
        vec![1, 3, 2],
        false,
    )
    .unwrap();
    let output = Module::<f32>::forward(&lrn, &input).unwrap();
    output.data().unwrap().iter().for_each(|&v| {
        assert!(
            v > 0.0 && v < 1.0,
            "LRN output {v} should be attenuated (0 < v < 1)"
        );
    });
}
6185
#[test]
fn test_lrn_backward_numerical() {
    let lrn = LocalResponseNorm::new(3, 1e-4, 0.75, 1.0).unwrap();
    let shape = vec![1usize, 3, 4];
    let input_data: Vec<f64> = vec![
        1.0, -0.5, 2.0, 0.3, 0.7, -1.2, 0.4, 1.5, -0.3, 0.8, 1.1, -0.7,
    ];

    // Analytic gradient of sum(LRN(x)) through autograd.
    let input = leaf(&input_data, &shape, true);
    let output = Module::<f64>::forward(&lrn, &input).unwrap();
    let total: f64 = output.data().unwrap().iter().sum();
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(SumBackwardHelper {
            input: output.clone(),
        }),
    )
    .unwrap();
    loss.backward().unwrap();
    let grad_tensor = input.grad().unwrap().unwrap();
    let analytic = grad_tensor.data().unwrap().to_vec();

    // The same module is reused for every finite-difference probe.
    let loss_at = |values: &[f64]| -> f64 {
        let x = leaf(values, &shape, false);
        let y = no_grad(|| Module::<f64>::forward(&lrn, &x)).unwrap();
        let sum: f64 = y.data().unwrap().iter().sum();
        sum
    };

    let h = 1e-6;
    for i in 0..input_data.len() {
        let mut plus = input_data.clone();
        plus[i] += h;
        let mut minus = input_data.clone();
        minus[i] -= h;
        let numerical = (loss_at(&plus) - loss_at(&minus)) / (2.0 * h);
        assert!(
            (numerical - analytic[i]).abs() < 1e-4,
            "LRN grad[{i}]: numerical={numerical}, analytic={}",
            analytic[i]
        );
    }
}
6230
#[test]
fn test_instancenorm1d_output_shape() {
    // An (N, C, L) = (2, 3, 8) input keeps its shape.
    let input =
        Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 2 * 3 * 8]), vec![2, 3, 8], false)
            .unwrap();
    let norm = InstanceNorm1d::<f32>::new(3, 1e-5, true).unwrap();
    let out = norm.forward(&input).unwrap();
    assert_eq!(out.shape(), &[2, 3, 8]);
}
6245
#[test]
fn test_instancenorm1d_rejects_wrong_ndim() {
    // A rank-4 tensor is not a valid InstanceNorm1d input.
    let rank4 = Tensor::from_storage(
        TensorStorage::cpu(vec![1.0f32; 2 * 3 * 4 * 2]),
        vec![2, 3, 4, 2],
        false,
    )
    .unwrap();
    let norm = InstanceNorm1d::<f32>::new(3, 1e-5, true).unwrap();
    assert!(norm.forward(&rank4).is_err());
}
6258
#[test]
fn test_instancenorm2d_normalizes_per_instance_channel() {
    // One instance with two 2x2 channels: channel 0 = 1..=4, channel 1 = 5..=8.
    let data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
    let input =
        Tensor::from_storage(TensorStorage::cpu(data), vec![1, 2, 2, 2], false).unwrap();
    let norm = InstanceNorm2d::<f32>::new(2, 1e-5, true).unwrap();
    let out = norm.forward(&input).unwrap();
    let d = out.data().unwrap();

    // Each (instance, channel) plane is 4 contiguous values and must be
    // standardised on its own.
    let plane = 4;
    for c in 0..2 {
        let slice = &d[c * plane..(c + 1) * plane];
        let mean: f32 = slice.iter().sum::<f32>() / 4.0;
        let var: f32 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / 4.0;
        assert!(mean.abs() < 1e-5, "channel {c} mean = {mean}, expected ~0");
        assert!(
            (var - 1.0).abs() < 0.1,
            "channel {c} var = {var}, expected ~1"
        );
    }
}
6287
#[test]
fn test_instancenorm2d_rejects_wrong_ndim() {
    // A rank-3 tensor is not a valid InstanceNorm2d input.
    let rank3 =
        Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 2 * 3 * 4]), vec![2, 3, 4], false)
            .unwrap();
    let norm = InstanceNorm2d::<f32>::new(3, 1e-5, true).unwrap();
    assert!(norm.forward(&rank3).is_err());
}
6297
#[test]
fn test_instancenorm3d_output_shape() {
    // A non-affine InstanceNorm3d keeps the shape of a (1, 2, 2, 2, 2) input.
    let dims: Vec<usize> = vec![1, 2, 2, 2, 2];
    let input = Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 16]), dims, false).unwrap();
    let norm = InstanceNorm3d::<f32>::new(2, 1e-5, false).unwrap();
    let out = norm.forward(&input).unwrap();
    assert_eq!(out.shape(), &[1, 2, 2, 2, 2]);
}
6311
#[test]
fn test_instancenorm2d_no_affine_no_params() {
    // Without affine there is nothing learnable.
    let norm = InstanceNorm2d::<f32>::new(4, 1e-5, false).unwrap();
    assert_eq!(Module::<f32>::parameters(&norm).len(), 0);
}
6317
#[test]
fn test_instancenorm2d_has_affine_params() {
    let norm = InstanceNorm2d::<f32>::new(4, 1e-5, true).unwrap();
    let params = Module::<f32>::parameters(&norm);
    assert_eq!(params.len(), 2);
    // Both affine parameters are per-channel vectors.
    params.iter().for_each(|p| assert_eq!(p.shape(), &[4]));
}
6326
#[test]
fn test_instancenorm2d_backward_gradient_check() {
    let h = 1e-7;
    let num_features = 2;
    let shape = vec![1usize, 2, 2, 2];
    let input_data: Vec<f64> = vec![1.0, -0.5, 2.0, 0.3, 0.7, -1.2, 0.4, 1.5];

    let norm = InstanceNorm2d::<f64>::new(num_features, 1e-5, true).unwrap();

    // Analytic gradient of sum(InstanceNorm2d(x)) through autograd.
    let input = leaf(&input_data, &shape, true);
    let output = norm.forward(&input).unwrap();
    let total: f64 = output.data().unwrap().iter().sum();
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(SumBackwardHelper {
            input: output.clone(),
        }),
    )
    .unwrap();
    loss.backward().unwrap();
    let grad_tensor = input.grad().unwrap().unwrap();
    let analytic = grad_tensor.data().unwrap().to_vec();

    // The same module is reused for every finite-difference probe.
    let loss_at = |values: &[f64]| -> f64 {
        let x = leaf(values, &shape, false);
        let y = no_grad(|| norm.forward(&x)).unwrap();
        let sum: f64 = y.data().unwrap().iter().sum();
        sum
    };

    for i in 0..input_data.len() {
        let mut plus = input_data.clone();
        plus[i] += h;
        let mut minus = input_data.clone();
        minus[i] -= h;
        let numerical = (loss_at(&plus) - loss_at(&minus)) / (2.0 * h);
        assert!(
            (numerical - analytic[i]).abs() < 1e-4,
            "InstanceNorm2d grad[{i}]: numerical={numerical}, analytic={}",
            analytic[i]
        );
    }
}
6374
#[test]
fn test_instancenorm_zero_features_rejected() {
    // Every InstanceNorm variant must refuse zero channels.
    let rejected = [
        InstanceNorm1d::<f32>::new(0, 1e-5, true).is_err(),
        InstanceNorm2d::<f32>::new(0, 1e-5, true).is_err(),
        InstanceNorm3d::<f32>::new(0, 1e-5, true).is_err(),
    ];
    assert!(rejected.iter().all(|&r| r));
}
6381
#[test]
fn bn2d_set_running_mean_round_trip() {
    let v: [f32; 4] = [0.5, -1.25, 2.0, 0.0];
    let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
    bn.set_running_mean(&v).unwrap();

    // The f32 values written must read back (widened to f64) unchanged.
    let got = bn.running_mean();
    assert_eq!(got.len(), v.len());
    for (i, (g, e)) in got.iter().zip(&v).enumerate() {
        assert!(
            (g - f64::from(*e)).abs() < 1e-7,
            "channel {i}: got={g}, expected={e}"
        );
    }
}
6416
#[test]
fn bn2d_set_running_var_round_trip() {
    let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
    let v: [f32; 4] = [1.5, 0.25, 2.0, 0.0];
    bn.set_running_var(&v).unwrap();

    // The f32 values written must read back (widened to f64) unchanged.
    let got = bn.running_var();
    got.iter().zip(&v).enumerate().for_each(|(i, (g, e))| {
        assert!(
            (g - f64::from(*e)).abs() < 1e-7,
            "channel {i}: got={g}, expected={e}"
        );
    });
}
6430
#[test]
fn bn2d_set_num_batches_tracked_round_trip() {
    let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
    // Write a non-zero count, then reset it to zero; each value reads back verbatim.
    for count in [42, 0] {
        bn.set_num_batches_tracked(count).unwrap();
        assert_eq!(bn.num_batches_tracked(), count);
    }
}
6440
#[test]
fn bn1d_set_running_mean_round_trip() {
    let v: [f32; 3] = [-0.1, 0.2, 0.3];
    let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    bn.set_running_mean(&v).unwrap();

    // The f32 values written must read back (widened to f64) unchanged.
    let got = bn.running_mean();
    for (i, (g, e)) in got.iter().zip(&v).enumerate() {
        assert!(
            (g - f64::from(*e)).abs() < 1e-7,
            "channel {i}: got={g}, expected={e}"
        );
    }
}
6454
#[test]
fn bn1d_set_running_var_round_trip() {
    let v: [f32; 3] = [0.5, 1.0, 4.0];
    let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    bn.set_running_var(&v).unwrap();

    // The f32 values written must read back (widened to f64) unchanged.
    let got = bn.running_var();
    got.iter().zip(&v).enumerate().for_each(|(i, (g, e))| {
        assert!(
            (g - f64::from(*e)).abs() < 1e-7,
            "channel {i}: got={g}, expected={e}"
        );
    });
}
6468
#[test]
fn bn1d_set_num_batches_tracked_round_trip() {
    let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    // The stored batch counter reads back exactly what was written.
    let count = 7;
    bn.set_num_batches_tracked(count).unwrap();
    assert_eq!(bn.num_batches_tracked(), count);
}
6475
#[test]
fn bn3d_set_running_mean_round_trip() {
    let v: [f32; 2] = [-2.0, 3.0];
    let bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    bn.set_running_mean(&v).unwrap();

    // The f32 values written must read back (widened to f64) unchanged.
    let got = bn.running_mean();
    for (i, (g, e)) in got.iter().zip(&v).enumerate() {
        assert!(
            (g - f64::from(*e)).abs() < 1e-7,
            "channel {i}: got={g}, expected={e}"
        );
    }
}
6489
#[test]
fn bn3d_set_running_var_round_trip() {
    let v: [f32; 2] = [1.0, 0.5];
    let bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    bn.set_running_var(&v).unwrap();

    // The f32 values written must read back (widened to f64) unchanged.
    let got = bn.running_var();
    got.iter().zip(&v).enumerate().for_each(|(i, (g, e))| {
        assert!(
            (g - f64::from(*e)).abs() < 1e-7,
            "channel {i}: got={g}, expected={e}"
        );
    });
}
6503
#[test]
fn bn3d_set_num_batches_tracked_round_trip() {
    let bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    // The stored batch counter reads back exactly what was written.
    let count = 11;
    bn.set_num_batches_tracked(count).unwrap();
    assert_eq!(bn.num_batches_tracked(), count);
}
6510
#[test]
fn bn2d_set_running_mean_rejects_wrong_length() {
    let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();

    // Too short: a ShapeMismatch naming the setter and the expected feature count.
    let too_short: [f32; 3] = [0.0, 0.0, 0.0];
    let err = bn
        .set_running_mean(&too_short)
        .expect_err("wrong length should error");
    let message = match err {
        FerrotorchError::ShapeMismatch { message } => message,
        other => panic!("expected ShapeMismatch, got {other:?}"),
    };
    assert!(message.contains("BatchNorm2d::set_running_mean"));
    assert!(message.contains("num_features=4"));

    // Too long: also rejected.
    let too_long: [f32; 5] = [0.0; 5];
    bn.set_running_mean(&too_long)
        .expect_err("wrong length should error");
}
6531
#[test]
fn bn2d_set_running_mean_rejects_non_finite() {
    let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();

    // NaN at position 1: an InvalidArgument that points at the offending index.
    let nan_val: [f32; 2] = [0.0, f32::NAN];
    let err = bn.set_running_mean(&nan_val).expect_err("nan should error");
    let message = match err {
        FerrotorchError::InvalidArgument { message } => message,
        other => panic!("expected InvalidArgument, got {other:?}"),
    };
    assert!(message.contains("non-finite"));
    assert!(message.contains("index 1"));

    // Infinity is rejected as well.
    let inf_val: [f32; 2] = [f32::INFINITY, 0.0];
    bn.set_running_mean(&inf_val).expect_err("inf should error");
}
6546
#[test]
fn bn2d_set_running_var_rejects_negative() {
    let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    // A variance cannot be negative; the error must name the offending index.
    let v: [f32; 2] = [1.0, -0.5];
    let err = bn
        .set_running_var(&v)
        .expect_err("negative variance should error");
    let message = match err {
        FerrotorchError::InvalidArgument { message } => message,
        other => panic!("expected InvalidArgument, got {other:?}"),
    };
    assert!(message.contains("negative"));
    assert!(message.contains("index 1"));
}
6562
#[test]
fn bn2d_set_running_var_rejects_non_finite() {
    let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    // NaN variance must surface as an InvalidArgument mentioning non-finite input.
    let v: [f32; 2] = [1.0, f32::NAN];
    let err = bn
        .set_running_var(&v)
        .expect_err("nan variance should error");
    let message = match err {
        FerrotorchError::InvalidArgument { message } => message,
        other => panic!("expected InvalidArgument, got {other:?}"),
    };
    assert!(message.contains("non-finite"));
}
6577
/// `BatchNorm1d::set_running_var` must refuse both a negative entry and a
/// slice whose length differs from `num_features` (3).
#[test]
fn bn1d_set_running_var_rejects_negative_and_wrong_length() {
    let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    // Each case is (values, expectation message); all must be rejected, in order.
    let cases: [(&[f32], &str); 2] = [
        (&[1.0, -1.0, 1.0], "negative should error"),
        (&[1.0, 1.0], "wrong length should error"),
    ];
    for (values, why) in cases {
        bn.set_running_var(values).expect_err(why);
    }
}
6588
/// `BatchNorm3d::set_running_mean` must refuse a NaN entry and a slice shorter
/// than `num_features` (2).
#[test]
fn bn3d_set_running_mean_rejects_non_finite_and_wrong_length() {
    let bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    // Each case is (values, expectation message); all must be rejected, in order.
    let cases: [(&[f32], &str); 2] = [
        (&[f32::NAN, 0.0], "nan should error"),
        (&[0.0], "wrong length should error"),
    ];
    for (values, why) in cases {
        bn.set_running_mean(values).expect_err(why);
    }
}
6598
/// Eval-mode `forward` must normalise with the running statistics installed
/// through `set_running_mean` / `set_running_var`, i.e.
/// `y = (x - mean) / sqrt(var + eps)`. The expected values assume the freshly
/// built module's affine weight/bias act as the identity.
#[test]
fn bn2d_set_running_stats_flow_through_eval_forward() {
    let mut bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    bn.set_running_mean(&[3.0_f32, -2.0]).unwrap();
    bn.set_running_var(&[4.0_f32, 0.25]).unwrap();
    bn.eval();

    // Shape [N=1, C=2, H=1, W=1]: a single value per channel.
    let storage = TensorStorage::cpu(vec![5.0_f32, 1.0]);
    let input = Tensor::from_storage(storage, vec![1, 2, 1, 1], false).unwrap();
    let out = bn.forward(&input).unwrap();
    let data = out.data_vec().unwrap();
    assert_eq!(data.len(), 2);

    // ch0: (5 - 3) / sqrt(4 + eps);  ch1: (1 - (-2)) / sqrt(0.25 + eps)
    let expected = [
        2.0_f32 / (4.0_f32 + 1e-5).sqrt(),
        3.0_f32 / (0.25_f32 + 1e-5).sqrt(),
    ];
    for (ch, (got, want)) in data.iter().zip(expected).enumerate() {
        assert!(
            (got - want).abs() < 1e-5,
            "channel {ch}: got {got}, expected {want}"
        );
    }

    // Guard against a setter that silently no-ops: with default statistics the
    // output would sit right next to the raw inputs (5.0 and 1.0).
    assert!(
        (data[0] - 5.0_f32).abs() > 1.0,
        "data[0]={} too close to default-stats output (5.0); setter is not flowing through",
        data[0]
    );
    assert!(
        (data[1] - 1.0_f32).abs() > 1.0,
        "data[1]={} too close to default-stats output (1.0); setter is not flowing through",
        data[1]
    );
}
6662
/// `BatchNorm1d` eval-mode `forward` must normalise with running statistics
/// installed through the setters: `y = (x - mean) / sqrt(var + eps)`.
/// The input has shape `[N=1, C=2]`; the expected values assume the freshly
/// built module's affine weight/bias act as the identity.
#[test]
fn bn1d_set_running_stats_flow_through_eval_forward() {
    let mut bn = BatchNorm1d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    bn.set_running_mean(&[1.0_f32, 0.0]).unwrap();
    bn.set_running_var(&[9.0_f32, 4.0]).unwrap();
    bn.eval();

    let input = Tensor::from_storage(TensorStorage::cpu(vec![4.0_f32, 6.0]), vec![1, 2], false)
        .unwrap();
    let out = bn.forward(&input).unwrap();
    let data = out.data_vec().unwrap();
    // Indexing below only catches a too-short output; pin the exact element
    // count (as the BatchNorm2d variant of this test does).
    assert_eq!(data.len(), 2);

    // ch0: (4 - 1) / sqrt(9 + eps);  ch1: (6 - 0) / sqrt(4 + eps)
    let expected_0 = 3.0_f32 / (9.0_f32 + 1e-5).sqrt();
    let expected_1 = 6.0_f32 / (4.0_f32 + 1e-5).sqrt();
    assert!(
        (data[0] - expected_0).abs() < 1e-5,
        "BN1d ch0: got {}, expected {}",
        data[0],
        expected_0
    );
    assert!(
        (data[1] - expected_1).abs() < 1e-5,
        "BN1d ch1: got {}, expected {}",
        data[1],
        expected_1
    );
}
6693
/// `BatchNorm3d` eval-mode `forward` must normalise with running statistics
/// installed through the setters: `y = (x - mean) / sqrt(var + eps)`.
/// The input has shape `[N=1, C=2, D=1, H=1, W=1]`; the expected values assume
/// the freshly built module's affine weight/bias act as the identity.
#[test]
fn bn3d_set_running_stats_flow_through_eval_forward() {
    let mut bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    bn.set_running_mean(&[2.0_f32, -1.0]).unwrap();
    bn.set_running_var(&[1.0_f32, 16.0]).unwrap();
    bn.eval();

    let input = Tensor::from_storage(
        TensorStorage::cpu(vec![3.0_f32, 7.0]),
        vec![1, 2, 1, 1, 1],
        false,
    )
    .unwrap();
    let out = bn.forward(&input).unwrap();
    let data = out.data_vec().unwrap();
    // Indexing below only catches a too-short output; pin the exact element
    // count (as the BatchNorm2d variant of this test does).
    assert_eq!(data.len(), 2);

    // ch0: (3 - 2) / sqrt(1 + eps);  ch1: (7 - (-1)) / sqrt(16 + eps)
    let expected_0 = 1.0_f32 / (1.0_f32 + 1e-5).sqrt();
    let expected_1 = 8.0_f32 / (16.0_f32 + 1e-5).sqrt();
    assert!(
        (data[0] - expected_0).abs() < 1e-5,
        "BN3d ch0: got {}, expected {}",
        data[0],
        expected_0
    );
    assert!(
        (data[1] - expected_1).abs() < 1e-5,
        "BN3d ch1: got {}, expected {}",
        data[1],
        expected_1
    );
}
6726
/// Viewed through `&dyn Module<f32>`, a `BatchNorm2d` exposes an `Any` via
/// `as_any` that downcasts to `BatchNorm2d<f32>` and to neither sibling BN type.
#[test]
fn bn2d_as_any_downcasts_to_concrete_type() {
    let module = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
    let erased: &dyn Module<f32> = &module;
    let any_ref = erased
        .as_any()
        .expect("BatchNorm2d::as_any returns Some");

    // The target type is inferred from the binding's annotation.
    let recovered: &BatchNorm2d<f32> = any_ref
        .downcast_ref()
        .expect("any must downcast to BatchNorm2d<f32>");
    assert_eq!(recovered.num_features, 4);

    assert!(any_ref.downcast_ref::<BatchNorm1d<f32>>().is_none());
    assert!(any_ref.downcast_ref::<BatchNorm3d<f32>>().is_none());
}
6744
/// Viewed through `&dyn Module<f32>`, a `BatchNorm1d` exposes an `Any` via
/// `as_any` that downcasts to `BatchNorm1d<f32>` only — mirroring the
/// BatchNorm2d downcast test, including its negative checks.
#[test]
fn bn1d_as_any_downcasts_to_concrete_type() {
    let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
    let dyn_module: &dyn Module<f32> = &bn;
    let any = dyn_module
        .as_any()
        .expect("BatchNorm1d::as_any returns Some");
    let concrete = any
        .downcast_ref::<BatchNorm1d<f32>>()
        .expect("any must downcast to BatchNorm1d<f32>");
    assert_eq!(concrete.num_features, 3);
    // A positive downcast alone cannot tell a correct hook from one that
    // matches every BN type; require the sibling types to be rejected.
    assert!(any.downcast_ref::<BatchNorm2d<f32>>().is_none());
    assert!(any.downcast_ref::<BatchNorm3d<f32>>().is_none());
}
6757
/// Viewed through `&dyn Module<f32>`, a `BatchNorm3d` exposes an `Any` via
/// `as_any` that downcasts to `BatchNorm3d<f32>` only — mirroring the
/// BatchNorm2d downcast test, including its negative checks.
#[test]
fn bn3d_as_any_downcasts_to_concrete_type() {
    let bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    let dyn_module: &dyn Module<f32> = &bn;
    let any = dyn_module
        .as_any()
        .expect("BatchNorm3d::as_any returns Some");
    let concrete = any
        .downcast_ref::<BatchNorm3d<f32>>()
        .expect("any must downcast to BatchNorm3d<f32>");
    assert_eq!(concrete.num_features, 2);
    // A positive downcast alone cannot tell a correct hook from one that
    // matches every BN type; require the sibling types to be rejected.
    assert!(any.downcast_ref::<BatchNorm1d<f32>>().is_none());
    assert!(any.downcast_ref::<BatchNorm2d<f32>>().is_none());
}
6770
/// A module that has not opted into the `as_any` downcast hook (here
/// `LayerNorm`) must return `None` from it.
#[test]
fn non_bn_module_as_any_returns_none() {
    let layer_norm = LayerNorm::<f32>::new(vec![4], 1e-5, true).unwrap();
    let erased: &dyn Module<f32> = &layer_norm;
    let hook = erased.as_any();
    assert!(
        hook.is_none(),
        "non-BN modules must not opt into the as_any downcast hook"
    );
}
6783
/// Installing running statistics through the setters must leave
/// `num_batches_tracked` at its initial value of 0.
#[test]
fn bn2d_set_running_mean_does_not_touch_nbt() {
    let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
    let tracked = || bn.num_batches_tracked();

    assert_eq!(tracked(), 0);
    bn.set_running_mean(&[0.5_f32, 0.5]).unwrap();
    assert_eq!(tracked(), 0);
    bn.set_running_var(&[1.0_f32, 1.0]).unwrap();
    assert_eq!(tracked(), 0);
}
6796
/// `GroupNorm` forward on CUDA must match the CPU reference within 1e-4 for
/// identical affine parameters. Ignored by default because it needs CUDA
/// hardware; unlike the other GPU tests it panics if CUDA init fails.
#[test]
#[cfg(feature = "cuda")]
#[ignore = "needs CUDA hardware; tracking #1356/#1357"]
fn group_norm_forward_gpu_matches_cpu() {
    use crate::module::Module as _;
    use ferrotorch_core::Device;
    use ferrotorch_gpu::init_cuda_backend;
    init_cuda_backend().expect("CUDA init failed");

    let (b, c, h, w) = (2usize, 8usize, 3usize, 4usize);
    let n = b * c * h * w;
    let data: Vec<f32> = (0..n).map(|k| ((k % 17) as f32) * 0.13 - 1.1).collect();
    let gamma: Vec<f32> = (0..c).map(|k| 1.0 + 0.05 * (k as f32)).collect();
    let beta: Vec<f32> = (0..c).map(|k| -0.1 + 0.02 * (k as f32)).collect();

    // Presumably 4 groups over the `c` channels — TODO confirm argument order.
    let mut gn = GroupNorm::<f32>::new(4, c, 1e-5, true).unwrap();
    gn.weight
        .set_data(Tensor::from_storage(TensorStorage::cpu(gamma), vec![c], false).unwrap());
    gn.bias
        .set_data(Tensor::from_storage(TensorStorage::cpu(beta), vec![c], false).unwrap());

    // `data` is not used afterwards, so it is moved rather than cloned.
    let x_cpu = Tensor::from_storage(TensorStorage::cpu(data), vec![b, c, h, w], false).unwrap();
    let y_cpu = gn.forward(&x_cpu).unwrap();
    let cpu_vals = y_cpu.data().unwrap().to_vec();

    gn.to_device(Device::Cuda(0)).unwrap();
    let x_gpu = x_cpu.to(Device::Cuda(0)).unwrap();
    let y_gpu = gn.forward(&x_gpu).unwrap();
    assert!(y_gpu.is_cuda(), "GroupNorm GPU output must stay on CUDA");
    let gpu_vals = y_gpu.data_vec().unwrap();

    assert_eq!(gpu_vals.len(), cpu_vals.len());
    // The CPU element is named `cv` so it does not shadow the channel count `c`.
    let max_abs = gpu_vals
        .iter()
        .zip(&cpu_vals)
        .map(|(g, cv)| (g - cv).abs())
        .fold(0.0f32, f32::max);
    assert!(max_abs < 1e-4, "GroupNorm GPU vs CPU max|Δ| = {max_abs}");
}
6851
/// Eval-mode `BatchNorm2d` on CUDA must agree with the CPU path within 1e-4
/// when both modules carry identical affine parameters and running stats.
/// Returns early (and so passes) when the CUDA backend cannot be initialised.
#[test]
#[cfg(feature = "cuda")]
fn batch_norm2d_eval_forward_gpu_matches_cpu() {
    use crate::module::Module as _;
    use ferrotorch_core::Device;
    use ferrotorch_gpu::init_cuda_backend;
    if init_cuda_backend().is_err() {
        return;
    }

    let (b, c, h, w) = (2usize, 6usize, 4usize, 5usize);
    let data: Vec<f32> = (0..b * c * h * w)
        .map(|k| ((k % 19) as f32) * 0.11 - 1.0)
        .collect();
    let gamma: Vec<f32> = (0..c).map(|k| 1.0 + 0.07 * (k as f32)).collect();
    let beta: Vec<f32> = (0..c).map(|k| -0.2 + 0.03 * (k as f32)).collect();
    let rmean: Vec<f32> = (0..c).map(|k| 0.05 * (k as f32) - 0.1).collect();
    let rvar: Vec<f32> = (0..c).map(|k| 0.8 + 0.05 * (k as f32)).collect();

    // Builds an identically configured module in eval mode for each device.
    let build_bn = || {
        let mut module = BatchNorm2d::<f32>::new(c, 1e-5, 0.1, true).unwrap();
        let weight =
            Tensor::from_storage(TensorStorage::cpu(gamma.clone()), vec![c], false).unwrap();
        module.weight.as_mut().unwrap().set_data(weight);
        let bias =
            Tensor::from_storage(TensorStorage::cpu(beta.clone()), vec![c], false).unwrap();
        module.bias.as_mut().unwrap().set_data(bias);
        module.set_running_mean(&rmean).unwrap();
        module.set_running_var(&rvar).unwrap();
        module.eval();
        module
    };

    let x_cpu = Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![b, c, h, w], false)
        .unwrap();
    let reference = build_bn();
    let cpu_vals = reference.forward(&x_cpu).unwrap().data().unwrap().to_vec();

    let mut on_gpu = build_bn();
    on_gpu.to_device(Device::Cuda(0)).unwrap();
    let x_gpu = x_cpu.to(Device::Cuda(0)).unwrap();
    let y_gpu = on_gpu.forward(&x_gpu).unwrap();
    assert!(y_gpu.is_cuda(), "BatchNorm2d GPU output must stay on CUDA");
    let gpu_vals = y_gpu.data_vec().unwrap();

    assert_eq!(gpu_vals.len(), cpu_vals.len());
    let max_abs = gpu_vals
        .iter()
        .zip(cpu_vals.iter())
        .map(|(g, cv)| (g - cv).abs())
        .fold(0.0f32, f32::max);
    assert!(
        max_abs < 1e-4,
        "BatchNorm2d eval GPU vs CPU max|Δ| = {max_abs}"
    );
}
6910
/// Training-mode `BatchNorm2d` on CUDA must match the CPU path within 1e-4,
/// both in its output and in the running statistics it updates after a single
/// batch (`num_batches_tracked == 1`). Returns early (and so passes) when the
/// CUDA backend cannot be initialised.
#[test]
#[cfg(feature = "cuda")]
fn batch_norm2d_train_forward_gpu_matches_cpu() {
    use crate::module::Module as _;
    use ferrotorch_core::Device;
    use ferrotorch_gpu::init_cuda_backend;
    if init_cuda_backend().is_err() {
        return;
    }

    let (b, c, h, w) = (4usize, 5usize, 3usize, 3usize);
    let n = b * c * h * w;
    let data: Vec<f32> = (0..n).map(|k| ((k as f32) * 0.037).sin() * 1.3).collect();
    let gamma: Vec<f32> = (0..c).map(|k| 0.9 + 0.04 * (k as f32)).collect();
    let beta: Vec<f32> = (0..c).map(|k| 0.02 * (k as f32)).collect();

    // No `eval()` call: the modules are used as constructed, which the
    // `num_batches_tracked == 1` assertion below relies on being training mode.
    let make = || {
        let mut bn = BatchNorm2d::<f32>::new(c, 1e-5, 0.1, true).unwrap();
        bn.weight.as_mut().unwrap().set_data(
            Tensor::from_storage(TensorStorage::cpu(gamma.clone()), vec![c], false).unwrap(),
        );
        bn.bias.as_mut().unwrap().set_data(
            Tensor::from_storage(TensorStorage::cpu(beta.clone()), vec![c], false).unwrap(),
        );
        bn
    };

    // `data` is not used afterwards, so it is moved rather than cloned.
    let x_cpu = Tensor::from_storage(TensorStorage::cpu(data), vec![b, c, h, w], false).unwrap();
    let bn_cpu = make();
    let cpu_vals = bn_cpu.forward(&x_cpu).unwrap().data().unwrap().to_vec();
    let cpu_rmean = bn_cpu.running_mean();
    let cpu_rvar = bn_cpu.running_var();

    let mut bn_gpu = make();
    bn_gpu.to_device(Device::Cuda(0)).unwrap();
    let x_gpu = x_cpu.to(Device::Cuda(0)).unwrap();
    let y_gpu = bn_gpu.forward(&x_gpu).unwrap();
    assert!(y_gpu.is_cuda(), "BatchNorm2d GPU output must stay on CUDA");
    let gpu_vals = y_gpu.data_vec().unwrap();

    // `zip` stops at the shorter side, so without this check a truncated GPU
    // output would pass the parity comparison vacuously.
    assert_eq!(gpu_vals.len(), cpu_vals.len());
    let max_abs = gpu_vals
        .iter()
        .zip(&cpu_vals)
        .map(|(g, cv)| (g - cv).abs())
        .fold(0.0f32, f32::max);
    assert!(
        max_abs < 1e-4,
        "BatchNorm2d train GPU vs CPU max|Δ| = {max_abs}"
    );

    let gpu_rmean = bn_gpu.running_mean();
    let gpu_rvar = bn_gpu.running_var();
    assert_eq!(bn_gpu.num_batches_tracked(), 1);
    for cc in 0..c {
        assert!(
            (gpu_rmean[cc] - cpu_rmean[cc]).abs() < 1e-4,
            "running_mean[{cc}] gpu={} cpu={}",
            gpu_rmean[cc],
            cpu_rmean[cc]
        );
        assert!(
            (gpu_rvar[cc] - cpu_rvar[cc]).abs() < 1e-4,
            "running_var[{cc}] gpu={} cpu={}",
            gpu_rvar[cc],
            cpu_rvar[cc]
        );
    }
}
6983
/// `InstanceNorm2d` forward on CUDA must match the CPU reference within 1e-4
/// for identical affine parameters. Returns early (and so passes) when the
/// CUDA backend cannot be initialised.
#[test]
#[cfg(feature = "cuda")]
fn instance_norm2d_forward_gpu_matches_cpu() {
    use crate::module::Module as _;
    use ferrotorch_core::Device;
    use ferrotorch_gpu::init_cuda_backend;
    if init_cuda_backend().is_err() {
        return;
    }

    let (b, c, h, w) = (3usize, 4usize, 5usize, 4usize);
    let n = b * c * h * w;
    let data: Vec<f32> = (0..n).map(|k| ((k % 23) as f32) * 0.09 - 0.8).collect();
    let gamma: Vec<f32> = (0..c).map(|k| 1.0 + 0.06 * (k as f32)).collect();
    let beta: Vec<f32> = (0..c).map(|k| -0.05 + 0.04 * (k as f32)).collect();

    // The affine parameters live on the wrapped `inner` module.
    let make = || {
        let mut inorm = InstanceNorm2d::<f32>::new(c, 1e-5, true).unwrap();
        inorm.inner.weight.set_data(
            Tensor::from_storage(TensorStorage::cpu(gamma.clone()), vec![c], false).unwrap(),
        );
        inorm.inner.bias.set_data(
            Tensor::from_storage(TensorStorage::cpu(beta.clone()), vec![c], false).unwrap(),
        );
        inorm
    };

    // `data` is not used afterwards, so it is moved rather than cloned.
    let x_cpu = Tensor::from_storage(TensorStorage::cpu(data), vec![b, c, h, w], false).unwrap();
    let cpu_vals = make().forward(&x_cpu).unwrap().data().unwrap().to_vec();

    let mut gpu = make();
    gpu.to_device(Device::Cuda(0)).unwrap();
    let x_gpu = x_cpu.to(Device::Cuda(0)).unwrap();
    let y_gpu = gpu.forward(&x_gpu).unwrap();
    assert!(
        y_gpu.is_cuda(),
        "InstanceNorm2d GPU output must stay on CUDA"
    );
    let gpu_vals = y_gpu.data_vec().unwrap();

    // `zip` stops at the shorter side, so without this check a truncated GPU
    // output would pass the parity comparison vacuously.
    assert_eq!(gpu_vals.len(), cpu_vals.len());
    let max_abs = gpu_vals
        .iter()
        .zip(&cpu_vals)
        .map(|(g, cv)| (g - cv).abs())
        .fold(0.0f32, f32::max);
    assert!(
        max_abs < 1e-4,
        "InstanceNorm2d GPU vs CPU max|Δ| = {max_abs}"
    );
}
7036}