//! Normalization layers: LayerNorm, GroupNorm, RMSNorm, BatchNorm1d/2d/3d,
//! InstanceNorm1d/2d/3d, LocalResponseNorm.
//!
//! Each layer normalizes its input along specified dimensions and optionally
//! applies a learnable affine transform (weight/bias). Backward functions
//! implement `GradFn<T>` to propagate gradients through the normalization.
//!
//! ## REQ status (per `.design/ferrotorch-nn/norm.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | `pub struct LayerNorm<T: Float>` + `impl<T: Float> Module<T> for LayerNorm<T>` mirrors `torch/nn/modules/normalization.py:105-238`; consumed by `ferrotorch-bert/src/attention.rs:18,194,209` (`pub layer_norm: LayerNorm<T>`) and `ferrotorch-whisper/src/layer.rs:12`. Runner-arm: #1447. |
//! | REQ-2 | SHIPPED | GPU fast path inside `LayerNorm::forward` dispatches to `backend.layernorm_f32(input, weight, bias, batch, norm_size, eps)` when `input.is_cuda() && self.elementwise_affine` mirroring `aten/src/ATen/native/Normalization.cpp` `native_layer_norm_cuda`; consumed by `ferrotorch-bert/src/attention.rs:209` and `ferrotorch-whisper/src/encoder.rs:27` pushing GPU-resident inputs through this path during inference. |
//! | REQ-3 | SHIPPED | `pub struct GroupNorm<T: Float>` + `GroupNormBackward<T>` mirrors `torch/nn/modules/normalization.py:239-342`; consumed by `ferrotorch-diffusion/src/vae.rs:39,99` (`pub conv_norm_out: GroupNorm<T>`). GPU forward fast path (#1357): `GroupNorm::forward` dispatches to `backend.group_norm_f32(...)` for CUDA input (kernel in `ferrotorch-gpu/src/group_norm.rs`). Runner-arm: #1447. |
//! | REQ-4 | SHIPPED | `pub struct RMSNorm<T: Float>` + `RMSNormBackward<T>` with `mean(x^2)` denominator (no centering, no bias) mirrors `torch/nn/modules/normalization.py:343-435`; consumed via `ferrotorch-nn/src/lib.rs:227` re-export (Llama / T5 stacks). |
//! | REQ-5 | SHIPPED | `pub struct BatchNorm2d<T: Float>` + `BatchNorm2dBackward<T>` with per-channel running stats and train/eval branching mirrors `torch/nn/modules/batchnorm.py:420-498`; consumed by `ferrotorch-distributed/src/sync_batch_norm.rs:3,595` and `ferrotorch-vision/src/models/segmentation/fcn.rs:34`. GPU fwd via `batch_norm_gpu_forward`; GPU backward (#1449) via `batch_norm_gpu_backward` → `backend.batch_norm_backward_f32` (kernel `gpu_batch_norm_backward_f32` in `ferrotorch-gpu/src/group_norm.rs`, mirrors `aten/src/ATen/native/cuda/Normalization.cuh:388`), on-device, NO `.cpu()` round trip; live-vs-torch grad parity (<1e-3) pinned by `divergence_batchnorm2d_gpu_{train,eval}_backward_vs_torch`. Runner-arm: #1447. |
//! | REQ-6 | SHIPPED | `pub struct BatchNorm1d<T: Float>` + `BatchNorm1dBackward<T>` mirrors `torch/nn/modules/batchnorm.py:306-383`; consumed by `ferrotorch-nn/src/lazy_norm.rs:154` (`lazy_batchnorm!(LazyBatchNorm1d, BatchNorm1d, ...)`) and `ferrotorch-nn/src/lib.rs:225-228` re-exports. GPU fwd + backward (#1449) via the shared `batch_norm_gpu_backward` helper, on-device. |
//! | REQ-7 | SHIPPED | `pub struct BatchNorm3d<T: Float>` + `BatchNorm3dBackward<T>` mirrors `torch/nn/modules/batchnorm.py:535-613`; consumed by `ferrotorch-nn/src/lazy_norm.rs:155` and `lib.rs:225-228` re-exports. GPU fwd + backward (#1449) via the shared `batch_norm_gpu_backward` helper, on-device. |
//! | REQ-8 | SHIPPED | `pub fn set_running_mean / set_running_var / set_num_batches_tracked` plus matching read accessors on every `BatchNorm{1,2,3}d<T>` with finite + non-negative validation; consumed by `ferrotorch-distributed/src/sync_batch_norm.rs` all-reducing `running_mean()` / `running_var()` and by vision state-dict loaders downcasting via `as_any`. Pinned by `bn{1,2,3}d_set_running_*_round_trip` and `bn{1,2,3}d_set_running_stats_flow_through_eval_forward`. |
//! | REQ-9 | SHIPPED | `pub struct InstanceNorm1d<T>`, `pub struct InstanceNorm2d<T>`, `pub struct InstanceNorm3d<T>` newtypes around `InstanceNormInner<T>` + `InstanceNormBackward<T>` mirror `torch/nn/modules/instancenorm.py`; consumed by `ferrotorch-nn/src/lazy_norm.rs:14` and `lib.rs:225-228` re-exports. GPU fwd via `group_norm_f32`; GPU backward (#1449) via `instance_norm_gpu_backward` (reshape `[B,C,S]`→`[1,B*C,S]` + `batch_norm_backward_f32` + `sum_axis_f32` reduce), on-device, NO `.cpu()` round trip; live-vs-torch grad parity (<1e-3) pinned by `divergence_instancenorm2d_gpu_backward_vs_torch`. |
//! | REQ-10 | SHIPPED | `pub struct LocalResponseNorm` + `LocalResponseNormBackward<T>` mirrors `torch/nn/modules/normalization.py:16-73`; consumed via `ferrotorch-nn/src/lib.rs:227` re-export. GPU fwd + backward (#1449) via `backend.local_response_norm_f32` / `local_response_norm_backward_f32` (kernels in `ferrotorch-gpu/src/group_norm.rs`, mirror `torch/nn/functional.py:3032-3046`), on-device with GPU-resident saved `denom`, NO `.cpu()` round trip; live-vs-torch fwd+grad parity (<1e-3) pinned by `divergence_local_response_norm_gpu_fwd_bwd_vs_torch`. |
//! | REQ-11 | SHIPPED | Every norm layer has `impl<T: Float> Module<T>` with the seven required methods; BatchNorm additionally implements `as_any() -> Option<&dyn Any>` returning `Some(self)`; consumed by `ferrotorch-bert/src/attention.rs` invoking `self.layer_norm.forward(...)` through the trait and by vision state-dict loaders downcasting via `as_any`. Pinned by `bn{1,2,3}d_as_any_downcasts_to_concrete_type`. |
//! | REQ-12 | SHIPPED | Every forward returns `Tensor::from_operation(storage, shape, grad_fn)` when `is_grad_enabled() && input.requires_grad()`; backward nodes `LayerNormBackward`, `GroupNormBackward`, `RMSNormBackward`, `BatchNorm{1,2,3}dBackward`, `InstanceNormBackward`, `LocalResponseNormBackward` all live in this file; consumed by `ferrotorch-optim` training loops driving `backward()` through these nodes when models composed from `ferrotorch-bert` / `ferrotorch-diffusion` / `ferrotorch-vision` are trained. |

use std::any::TypeId;
use std::sync::{Arc, Mutex};

use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::gpu_dispatch::gpu_backend;
use ferrotorch_core::tensor::GradFn;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};

use crate::module::Module;
use crate::parameter::Parameter;

#[inline]
fn is_f32<T: Float>() -> bool {
    TypeId::of::<T>() == TypeId::of::<f32>()
}

#[inline]
fn is_f64<T: Float>() -> bool {
    TypeId::of::<T>() == TypeId::of::<f64>()
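}

/// Shorthand for `<T as num_traits::Zero>::zero()`, written fully qualified
/// so the call stays unambiguous when more than one trait in scope provides
/// a `zero` method.
#[inline]
fn zero<T: Float>() -> T {
    <T as num_traits::Zero>::zero()
}

/// Read an f32-tagged GPU buffer handle back to a host `Vec<f32>`.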
///
/// `GpuBackend::gpu_to_cpu` returns raw bytes; the BatchNorm stat buffers are
/// f32, so reinterpret the byte stream as little-endian f32. Used only for the
/// tiny `[channels]` running-stat read-back in [`batch_norm_gpu_forward`].
fn gpu_handle_to_f32(
    backend: &dyn ferrotorch_core::gpu_dispatch::GpuBackend,
    handle: &ferrotorch_core::gpu_dispatch::GpuBufferHandle,
) -> FerrotorchResult<Vec<f32>> {
    let bytes = backend.gpu_to_cpu(handle)?;
    Ok(bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect())
}
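/*
Sketch (illustrative only, not part of this crate's API): the byte
reinterpretation that `gpu_handle_to_f32` performs, isolated from the GPU
backend so it can be exercised on the host. `decode_f32_le` is a hypothetical
name, and the sketch assumes tightly packed little-endian f32 values, as the
doc comment above states for the BatchNorm stat buffers.

```rust
/// Decode tightly packed little-endian `f32` values, four bytes each.
/// `chunks_exact(4)` skips a trailing partial chunk instead of failing.
fn decode_f32_le(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
```

Because `chunks_exact` silently ignores up to three trailing bytes, a stricter
variant could reject inputs whose length is not a multiple of 4 before decoding.
*/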

/// GPU forward for the BatchNorm family (#1449), shared by
/// `BatchNorm{1,2,3}d::forward`.
///
/// Dispatches to the backend `batch_norm_f32` kernel
/// (`ferrotorch-gpu/src/group_norm.rs::gpu_batch_norm_f32`), which performs
/// the per-channel normalization over `(batch, spatial)` and applies the
/// per-channel affine. `weight` / `bias` are materialized on-device as
/// `[channels]` buffers (ones / zeros when the layer is non-affine, so the
/// affine is the identity). In training mode the kernel returns the biased
/// batch mean / variance, which this helper folds into the `running_mean` /
/// `running_var` mutexes using the same momentum + Bessel-corrected variance
/// update as the CPU path. Mirrors
/// `aten/src/ATen/native/Normalization.cpp::batch_norm_cuda`.
///
/// f32-only (the kernel is f32); the caller must gate on `is_f32::<T>()`.
/// Returns the GPU output buffer handle (input shape).
#[allow(clippy::too_many_arguments)]
fn batch_norm_gpu_forward<T: Float>(
    input: &Tensor<T>,
    weight: Option<&Tensor<T>>,
    bias: Option<&Tensor<T>>,
    running_mean: &Mutex<Vec<f64>>,
    running_var: &Mutex<Vec<f64>>,
    num_batches_tracked: &Mutex<usize>,
    momentum: f64,
    eps: f64,
    channels: usize,
    spatial: usize,
    is_training: bool,
) -> FerrotorchResult<Option<ferrotorch_core::gpu_dispatch::GpuBufferHandle>> {
    let Some(backend) = gpu_backend() else {
        return Ok(None);
    };
    let batch = input.numel() / (channels * spatial.max(1));

    // Per-channel affine, materialized on the same device as `input`.
    // For the affine case we reuse the layer's f32 weight/bias tensors; for
    // the non-affine case we build ones/zeros and move them to the device.
    let weight_dev;
    let bias_dev;
    let (w_handle, b_handle) = match (weight, bias) {
        (Some(w), Some(b)) => (w.gpu_handle()?, b.gpu_handle()?),
        _ => {
            weight_dev = ferrotorch_core::creation::ones::<T>(&[channels])?.to(input.device())?;
            bias_dev = ferrotorch_core::creation::zeros::<T>(&[channels])?.to(input.device())?;
            (weight_dev.gpu_handle()?, bias_dev.gpu_handle()?)
        }
    };

    // Running stats live in f64; the kernel is f32. Push the current running
    // mean/var to the device (used directly in eval mode, ignored in training
    // mode where the kernel recomputes them).
    let rm_snapshot: Vec<f32> = running_mean
        .lock()
        .unwrap()
        .iter()
        .map(|v| *v as f32)
        .collect();
    let rv_snapshot: Vec<f32> = running_var
        .lock()
        .unwrap()
        .iter()
        .map(|v| *v as f32)
        .collect();
    let mean_in = ferrotorch_core::creation::from_slice::<f32>(&rm_snapshot, &[channels])?
        .to(input.device())?;
    let var_in = ferrotorch_core::creation::from_slice::<f32>(&rv_snapshot, &[channels])?
        .to(input.device())?;

    let (out_handle, mean_out, var_out) = backend.batch_norm_f32(
        input.gpu_handle()?,
        w_handle,
        b_handle,
        mean_in.gpu_handle()?,
        var_in.gpu_handle()?,
        batch,
        channels,
        spatial.max(1),
        eps as f32,
        is_training,
    )?;

    if is_training {
        // The kernel wrote the biased batch mean/var into `mean_out`/`var_out`.
        // Fold them into the running statistics exactly like the CPU path:
        //   running = (1 - momentum) * running + momentum * batch
        // with a Bessel correction applied to the variance term.
        let batch_mean = gpu_handle_to_f32(backend, &mean_out)?;
        let batch_var = gpu_handle_to_f32(backend, &var_out)?;
        let count = batch * spatial.max(1);
        let bessel = if count > 1 {
            count as f64 / (count as f64 - 1.0)
        } else {
            1.0
        };
        let mut rm = running_mean.lock().unwrap();
        let mut rv = running_var.lock().unwrap();
        let mut nbt = num_batches_tracked.lock().unwrap();
        *nbt += 1;
        for c in 0..channels {
            let bm = batch_mean[c] as f64;
            let bv = batch_var[c] as f64;
            rm[c] = (1.0 - momentum) * rm[c] + momentum * bm;
            rv[c] = (1.0 - momentum) * rv[c] + momentum * bv * bessel;
        }
    }

    Ok(Some(out_handle))
}
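/*
Sketch (illustrative only, not part of this crate's API): the running-stat
fold at the end of `batch_norm_gpu_forward`, written over plain slices so the
momentum and Bessel-correction arithmetic can be checked in isolation.
`update_running_stats` is a hypothetical helper; like the training branch
above, it assumes the incoming batch variance is the biased (divide-by-count)
estimate.

```rust
/// Fold one batch's per-channel statistics into the running estimates:
/// `running = (1 - momentum) * running + momentum * batch`, where the batch
/// variance is first rescaled by `count / (count - 1)` (Bessel correction).
fn update_running_stats(
    running_mean: &mut [f64],
    running_var: &mut [f64],
    batch_mean: &[f64],
    batch_var_biased: &[f64],
    momentum: f64,
    count: usize,
) {
    // With a single element per channel there is no unbiased estimate,
    // so the correction factor falls back to 1.0, as in the code above.
    let bessel = if count > 1 {
        count as f64 / (count as f64 - 1.0)
    } else {
        1.0
    };
    for c in 0..running_mean.len() {
        running_mean[c] = (1.0 - momentum) * running_mean[c] + momentum * batch_mean[c];
        running_var[c] =
            (1.0 - momentum) * running_var[c] + momentum * batch_var_biased[c] * bessel;
    }
}
```

For example, with `momentum = 0.1`, `count = 4`, a running variance of 1.0 and
a biased batch variance of 0.5, the updated running variance is
`0.9 * 1.0 + 0.1 * 0.5 * 4 / 3`, about 0.9667.
*/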

/// GPU backward for the BatchNorm / InstanceNorm family (#1449), shared by
/// `BatchNorm{1,2,3}dBackward` and `InstanceNormBackward`.
///
/// Dispatches to the backend `batch_norm_backward_f32` kernel
/// (`ferrotorch-gpu/src/group_norm.rs::gpu_batch_norm_backward_f32`), which
/// computes `(grad_input, grad_weight, grad_bias)` entirely on-device; there
/// is NO `.cpu()` round trip (R-CODE-4). All grad tensors stay GPU-resident.
/// Mirrors `aten/src/ATen/native/cuda/Normalization.cuh:388
/// batch_norm_backward_kernel`. f32-only; the caller must gate on `is_f32::<T>()`.
///
/// `mean`/`var` are length-`channels` host snapshots used only in eval mode
/// (`!training`); in training mode the kernel recomputes batch stats from
/// `input`, so the passed buffers are ignored but must be valid `[channels]`.
/// `weight_buf` is the affine scale; when the layer is non-affine the caller
/// passes an all-ones `[channels]` buffer so `grad_scale = invstd`.
/// `want_weight_grad` / `want_bias_grad` gate whether the (always-computed)
/// `grad_weight` / `grad_bias` buffers are surfaced to the autograd graph.
///
/// Returns `None` when no GPU backend is registered (caller falls back / errors).
#[allow(clippy::too_many_arguments)]
#[allow(
    clippy::fn_params_excessive_bools,
    reason = "the four flags (training / affine / want_weight_grad / want_bias_grad) each \
              control a distinct branch of the BatchNorm backward contract and mirror \
              PyTorch's train arg + grad_input_mask; no meaningful struct grouping exists"
)]
fn batch_norm_gpu_backward<T: Float>(
    input: &Tensor<T>,
    grad_output: &Tensor<T>,
    weight_buf: &Tensor<T>,
    mean: &[f64],
    var: &[f64],
    batch: usize,
    channels: usize,
    spatial: usize,
    eps: f64,
    training: bool,
    affine: bool,
    want_weight_grad: bool,
    want_bias_grad: bool,
) -> FerrotorchResult<Option<Vec<Option<Tensor<T>>>>> {
    let Some(backend) = gpu_backend() else {
        return Ok(None);
    };
    let mean_f32: Vec<f32> = mean.iter().map(|v| *v as f32).collect();
    let var_f32: Vec<f32> = var.iter().map(|v| *v as f32).collect();
    let mean_dev =
        ferrotorch_core::creation::from_slice::<f32>(&mean_f32, &[channels])?.to(input.device())?;
    let var_dev =
        ferrotorch_core::creation::from_slice::<f32>(&var_f32, &[channels])?.to(input.device())?;

    let (gi_h, gw_h, gb_h) = backend.batch_norm_backward_f32(
        input.gpu_handle()?,
        grad_output.gpu_handle()?,
        weight_buf.gpu_handle()?,
        mean_dev.gpu_handle()?,
        var_dev.gpu_handle()?,
        batch,
        channels,
        spatial.max(1),
        eps as f32,
        training,
    )?;

    let grad_input = Tensor::from_storage(TensorStorage::gpu(gi_h), input.shape().to_vec(), false)?;
    let grad_weight = if affine && want_weight_grad {
        Some(Tensor::from_storage(
            TensorStorage::gpu(gw_h),
            vec![channels],
            false,
        )?)
    } else {
        None
    };
    let grad_bias = if affine && want_bias_grad {
        Some(Tensor::from_storage(
            TensorStorage::gpu(gb_h),
            vec![channels],
            false,
        )?)
    } else {
        None
    };

    // The returned grad vec MUST match `BatchNorm{1,2,3}dBackward::inputs()`,
    // which registers only `input` when `affine == false` (no weight/bias
    // leaves) and `[input, weight, bias]` when affine. Returning a length-3
    // vec for a 1-input graph makes the autograd engine reject the node
    // ("backward returned 3 gradients but expected 1"). Mirrors PyTorch's
    // `grad_input_mask` in `aten/src/ATen/native/Normalization.cpp:322-330`,
    // where `grad_weight`/`grad_bias` are simply not produced when their
    // mask bits are unset.
    if affine {
        Ok(Some(vec![Some(grad_input), grad_weight, grad_bias]))
    } else {
        Ok(Some(vec![Some(grad_input)]))
    }
}
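/*
Sketch (illustrative only, not part of this crate's API): a plain-f64 CPU
reference for the training-mode contract that `batch_norm_gpu_backward` hands
to the kernel, usable as an oracle when testing the GPU path. All helper names
are hypothetical. The sketch assumes a contiguous `[batch, channels, spatial]`
layout and uses the standard BatchNorm VJP with `N = batch * spatial`,
`x_hat = (x - mean) * invstd` and `g = grad_output`: `grad_bias[c] = sum(g)`,
`grad_weight[c] = sum(g * x_hat)`, and
`grad_input = weight[c] * invstd / N * (N * g - sum(g) - x_hat * sum(g * x_hat))`.

```rust
/// Flat indices of every element of channel `c` in a contiguous
/// `[batch, channels, spatial]` buffer.
fn channel_indices(c: usize, [batch, channels, spatial]: [usize; 3]) -> Vec<usize> {
    (0..batch)
        .flat_map(|b| (0..spatial).map(move |s| (b * channels + c) * spatial + s))
        .collect()
}

/// Biased mean and `1 / sqrt(var + eps)` over the selected elements.
fn stats(x: &[f64], idx: &[usize], eps: f64) -> (f64, f64) {
    let n = idx.len() as f64;
    let mean = idx.iter().map(|&i| x[i]).sum::<f64>() / n;
    let var = idx.iter().map(|&i| (x[i] - mean).powi(2)).sum::<f64>() / n;
    (mean, 1.0 / (var + eps).sqrt())
}

/// Training-mode forward: `y = (x - mean) * invstd * weight[c] + bias[c]`.
fn bn_forward(x: &[f64], weight: &[f64], bias: &[f64], shape: [usize; 3], eps: f64) -> Vec<f64> {
    let mut y = vec![0.0f64; x.len()];
    for c in 0..shape[1] {
        let idx = channel_indices(c, shape);
        let (mean, invstd) = stats(x, &idx, eps);
        for &i in &idx {
            y[i] = (x[i] - mean) * invstd * weight[c] + bias[c];
        }
    }
    y
}

/// Training-mode backward; returns `(grad_input, grad_weight, grad_bias)`.
fn bn_backward(
    x: &[f64],
    grad_out: &[f64],
    weight: &[f64],
    shape: [usize; 3],
    eps: f64,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let channels = shape[1];
    let mut gi = vec![0.0f64; x.len()];
    let mut gw = vec![0.0f64; channels];
    let mut gb = vec![0.0f64; channels];
    for c in 0..channels {
        let idx = channel_indices(c, shape);
        let (mean, invstd) = stats(x, &idx, eps);
        let n = idx.len() as f64;
        let x_hat = |i: usize| (x[i] - mean) * invstd;
        gb[c] = idx.iter().map(|&i| grad_out[i]).sum::<f64>();
        gw[c] = idx.iter().map(|&i| grad_out[i] * x_hat(i)).sum::<f64>();
        for &i in &idx {
            gi[i] = weight[c] * invstd / n * (n * grad_out[i] - gb[c] - x_hat(i) * gw[c]);
        }
    }
    (gi, gw, gb)
}
```

Checking `grad_input` against a central finite difference of `sum(g * y)` on a
small input is a cheap way to validate either this reference or a kernel.
*/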

/// GPU backward for InstanceNorm (#1449).
///
/// InstanceNorm is BatchNorm applied per-instance: each `(b, c)` slice is its
/// own normalization group over the spatial dims. We reshape `[B, C, S]` to
/// `[1, B*C, S]` and reuse the on-device `batch_norm_backward_f32` kernel in
/// training mode (InstanceNorm always uses instance stats). The per-channel
/// affine `weight[c]` is tiled to `[B*C]`; the returned `grad_weight` /
/// `grad_bias` of length `B*C` are summed over the batch axis on-device
/// (`sum_axis_f32` over a `[B, C]` view) to `[C]`. `grad_input` keeps `[B,C,S]`.
/// The full gradient-data path stays GPU-resident, with NO `.cpu()` round trip
/// (R-CODE-4). f32-only; the caller gates on `is_f32::<T>()`.
///
/// Returns `None` when no GPU backend is registered.
#[allow(clippy::too_many_arguments)]
#[allow(
    clippy::fn_params_excessive_bools,
    reason = "affine / want_weight_grad / want_bias_grad each gate a distinct branch of the \
              InstanceNorm backward contract and mirror PyTorch's grad_input_mask; no \
              meaningful struct grouping exists"
)]
fn instance_norm_gpu_backward<T: Float>(
    input: &Tensor<T>,
    grad_output: &Tensor<T>,
    weight: &Tensor<T>,
    batch: usize,
    channels: usize,
    spatial: usize,
    eps: f64,
    affine: bool,
    want_weight_grad: bool,
    want_bias_grad: bool,
) -> FerrotorchResult<Option<Vec<Option<Tensor<T>>>>> {
    let Some(backend) = gpu_backend() else {
        return Ok(None);
    };
    let bc = batch * channels;

    // Tile the per-channel affine `weight[c]` to `[B*C]` (slot b*C+c = w[c]).
    // Non-affine ⇒ all ones, so the kernel's `grad_scale = invstd`.
    let weight_host: Vec<f32> = if affine {
        // `weight` is the (GPU-resident) per-channel affine param of length
        // `channels`; read it back (tiny) to tile to `[B*C]`. This is the affine
        // PARAMETER, not gradient data, so there is no grad round trip.
        let w = weight.data_vec()?;
        let mut tiled = vec![0.0f32; bc];
        for b in 0..batch {
            for c in 0..channels {
                tiled[b * channels + c] = w[c].to_f32().unwrap();
            }
        }
        tiled
    } else {
        vec![1.0f32; bc]
    };
    let weight_dev =
        ferrotorch_core::creation::from_slice::<f32>(&weight_host, &[bc])?.to(input.device())?;
    // Eval-mode stats are unused (training=true recomputes), but the kernel
    // needs valid `[B*C]` buffers.
    let stat_dummy = ferrotorch_core::creation::zeros::<f32>(&[bc])?.to(input.device())?;

    let (gi_h, gw_h, gb_h) = backend.batch_norm_backward_f32(
        input.gpu_handle()?,
        grad_output.gpu_handle()?,
        weight_dev.gpu_handle()?,
        stat_dummy.gpu_handle()?,
        stat_dummy.gpu_handle()?,
        1,  // batch == 1 for the reshaped [1, B*C, S] view
        bc, // channels == B*C
        spatial.max(1),
        eps as f32,
        true, // InstanceNorm always uses instance stats
    )?;

    let grad_input = Tensor::from_storage(TensorStorage::gpu(gi_h), input.shape().to_vec(), false)?;

    // Reduce grad_weight / grad_bias from `[B*C]` to `[C]` by summing the
    // `[B, C]` view over axis 0, entirely on-device.
    let reduce =
        |handle: ferrotorch_core::gpu_dispatch::GpuBufferHandle| -> FerrotorchResult<Tensor<T>> {
            let summed = backend.sum_axis_f32(&handle, &[batch, channels], 0)?;
            Tensor::from_storage(TensorStorage::gpu(summed), vec![channels], false)
        };
    let grad_weight = if affine && want_weight_grad {
        Some(reduce(gw_h)?)
    } else {
        None
    };
    let grad_bias = if affine && want_bias_grad {
        Some(reduce(gb_h)?)
    } else {
        None
    };

    Ok(Some(vec![Some(grad_input), grad_weight, grad_bias]))
}
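/*
Sketch (illustrative only, not part of this crate's API): the two index
transforms that `instance_norm_gpu_backward` depends on, shown on host slices.
`weight[c]` is tiled so that slot `b * C + c` of the `[B*C]` buffer holds
`w[c]`, and a row-major `[B, C]` buffer is summed over axis 0 back to `[C]`.
Summing is the right reduction because every batch element shares
`weight[c]`, so its gradient is the total of the per-`(b, c)` contributions.
Both function names are hypothetical; the code above performs the reduction
on-device via `sum_axis_f32`.

```rust
/// Tile a per-channel vector to `[batch * channels]`: slot `b * C + c` holds `w[c]`.
fn tile_per_channel(w: &[f32], batch: usize) -> Vec<f32> {
    (0..batch).flat_map(|_| w.iter().copied()).collect()
}

/// Sum a row-major `[batch, channels]` buffer over axis 0, giving `[channels]`.
fn sum_over_batch(v: &[f32], batch: usize, channels: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; channels];
    for b in 0..batch {
        for c in 0..channels {
            out[c] += v[b * channels + c];
        }
    }
    out
}
```
*/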

// ===========================================================================
// LayerNorm
// ===========================================================================

/// Layer normalization over the trailing `normalized_shape` dimensions.
///
/// Applies the transform:
///
/// ```text
/// y = (x - mean) / sqrt(var + eps) * weight + bias
/// ```
///
/// where `mean` and the biased `var` are computed over the last
/// `normalized_shape.len()` dimensions of the input, treated as one flattened
/// row of `normalized_shape.iter().product()` elements. The most common case
/// is a single-element `normalized_shape` such as `[hidden_dim]` (the
/// transformer hidden dim), but multi-dimensional trailing shapes are accepted.
///
/// Mirrors `torch.nn.LayerNorm` (`normalized_shape`, `eps`, `elementwise_affine`).
#[derive(Debug)]
pub struct LayerNorm<T: Float> {
    /// The trailing input shape to normalize over.
    pub normalized_shape: Vec<usize>,
    /// Small constant for numerical stability.
    pub eps: f64,
    /// Whether to apply learnable affine parameters.
    pub elementwise_affine: bool,
    /// Learnable scale (gamma), shape = `normalized_shape`.
    pub weight: Parameter<T>,
    /// Learnable shift (beta), shape = `normalized_shape`.
    pub bias: Parameter<T>,
    training: bool,
}

impl<T: Float> LayerNorm<T> {
    /// Create a new `LayerNorm` layer.
    ///
    /// # Arguments
    ///
    /// * `normalized_shape` - The trailing input dimensions to normalize over,
    ///   most often a single element such as `vec![hidden_dim]`.
    /// * `eps` - Small constant for numerical stability (PyTorch's default is 1e-5).
    /// * `elementwise_affine` - Whether to include learnable weight and bias.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] if `normalized_shape` is empty.
    pub fn new(
        normalized_shape: Vec<usize>,
        eps: f64,
        elementwise_affine: bool,
    ) -> FerrotorchResult<Self> {
        if normalized_shape.is_empty() {
            return Err(FerrotorchError::InvalidArgument {
                message: "normalized_shape must not be empty".into(),
            });
        }

        let weight = Parameter::ones(&normalized_shape)?;
        let bias = Parameter::zeros(&normalized_shape)?;

        Ok(Self {
            normalized_shape,
            eps,
            elementwise_affine,
            weight,
            bias,
            training: true,
        })
    }
}

impl<T: Float> Module<T> for LayerNorm<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape().to_vec();
        let ndim = shape.len();
        let norm_ndim = self.normalized_shape.len();

        if ndim < norm_ndim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "LayerNorm: input has {} dims but normalized_shape has {} dims",
                    ndim, norm_ndim
                ),
            });
        }

        // Verify that the last N dims of input match normalized_shape.
        let last_dims = &shape[ndim - norm_ndim..];
        if last_dims != self.normalized_shape.as_slice() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "LayerNorm: input last dims {:?} don't match normalized_shape {:?}",
                    last_dims, self.normalized_shape
                ),
            });
        }

        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = input.numel() / norm_size;

        // GPU fast path: native LayerNorm kernel.
        if input.is_cuda() && self.elementwise_affine {
            if let Some(backend) = ferrotorch_core::gpu_dispatch::gpu_backend() {
                let eps_f32 = self.eps as f32;
                let handle = backend.layernorm_f32(
                    input.gpu_handle()?,
                    self.weight.tensor().gpu_handle()?,
                    self.bias.tensor().gpu_handle()?,
                    batch_size,
                    norm_size,
                    eps_f32,
                )?;
                return if is_grad_enabled() && input.requires_grad() {
                    let grad_fn = Arc::new(LayerNormBackward {
                        input: input.clone(),
                        weight: self.weight.tensor().clone(),
                        bias: self.bias.tensor().clone(),
                        normalized_shape: self.normalized_shape.clone(),
                        eps: self.eps,
                        elementwise_affine: self.elementwise_affine,
                    });
                    Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
                } else {
                    Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
                };
            }
        }

        // CPU path. CUDA inputs that did not take the fast path above (no GPU
        // backend registered, or `elementwise_affine == false`) are rejected here.
        if input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "LayerNorm::forward",
            });
        }
        let input_data = input.data()?;
        let eps_t = T::from(self.eps).unwrap();
        let n_t = T::from(norm_size).unwrap();

        let weight_data = self.weight.tensor().data()?;
        let bias_data = self.bias.tensor().data()?;

        let mut output = Vec::with_capacity(input.numel());

        for b in 0..batch_size {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_data[start..end];

            let mean = slice.iter().copied().fold(zero::<T>(), |a, x| a + x) / n_t;
            let var = slice.iter().copied().fold(zero::<T>(), |a, x| {
                let d = x - mean;
                a + d * d
            }) / n_t;
            let inv_std = (var + eps_t).sqrt().recip();

            for (i, &x) in slice.iter().enumerate() {
                let normed = (x - mean) * inv_std;
                if self.elementwise_affine {
                    output.push(normed * weight_data[i] + bias_data[i]);
                } else {
                    output.push(normed);
                }
            }
        }

        let storage = TensorStorage::cpu(output);

        if is_grad_enabled() && input.requires_grad() {
            let grad_fn = Arc::new(LayerNormBackward {
                input: input.clone(),
                weight: self.weight.tensor().clone(),
                bias: self.bias.tensor().clone(),
                normalized_shape: self.normalized_shape.clone(),
                eps: self.eps,
                elementwise_affine: self.elementwise_affine,
            });
            Tensor::from_operation(storage, shape, grad_fn)
        } else {
            Tensor::from_storage(storage, shape, false)
        }
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        if self.elementwise_affine {
            vec![&self.weight, &self.bias]
        } else {
            vec![]
        }
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        if self.elementwise_affine {
            vec![&mut self.weight, &mut self.bias]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        if self.elementwise_affine {
            vec![
                ("weight".to_string(), &self.weight),
                ("bias".to_string(), &self.bias),
            ]
        } else {
            vec![]
        }
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
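/*
Sketch (illustrative only, not part of this crate's API): the CPU branch of
`LayerNorm::forward` reduced to plain slices. The input is viewed as
contiguous rows of `norm_size` elements (the product of `normalized_shape`),
and each row is normalized with its own mean and biased variance before the
optional elementwise affine. `layer_norm_ref` is a hypothetical name; the
sketch assumes a contiguous row-major f64 buffer whose length is a multiple
of a non-zero `norm_size`.

```rust
/// Normalize each row of `norm_size` elements, then apply `x_hat * w + b`
/// when `affine` is `Some((w, b))`.
fn layer_norm_ref(
    x: &[f64],
    norm_size: usize,
    eps: f64,
    affine: Option<(&[f64], &[f64])>,
) -> Vec<f64> {
    let n = norm_size as f64;
    let mut out = Vec::with_capacity(x.len());
    for row in x.chunks_exact(norm_size) {
        let mean = row.iter().sum::<f64>() / n;
        let var = row.iter().map(|v| (v - mean) * (v - mean)).sum::<f64>() / n;
        let inv_std = 1.0 / (var + eps).sqrt();
        for (i, &v) in row.iter().enumerate() {
            let normed = (v - mean) * inv_std;
            out.push(match affine {
                Some((w, b)) => normed * w[i] + b[i],
                None => normed,
            });
        }
    }
    out
}
```

For a row `[1, 2, 3, 4]` the mean is 2.5 and the biased variance 1.25, so the
first output is `-1.5 / sqrt(1.25 + eps)`; a constant row normalizes to zeros.
*/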

// ---------------------------------------------------------------------------
// LayerNormBackward
// ---------------------------------------------------------------------------

/// Backward node for LayerNorm.
///
/// Given forward: `y = (x - mean) / std * weight + bias`, with
/// `std = sqrt(var + eps)` and `x_hat = (x - mean) / std`. Let
/// `g = grad_output * weight` (or `g = grad_output` when non-affine).
///
/// The gradients are:
/// - `d_bias = sum(grad_output, over batch dims)`
/// - `d_weight = sum(grad_output * x_hat, over batch dims)`
/// - `d_input = (N * g - sum(g) - x_hat * sum(g * x_hat)) / (N * std)`, the
///   standard layer norm VJP, where `N` is the normalized row length and both
///   sums run over that row
///
/// Inputs stored: `[input, weight, bias]`.
/// Returns: `[grad_input, grad_weight, grad_bias]`.
#[derive(Debug)]
struct LayerNormBackward<T: Float> {
    input: Tensor<T>,
    weight: Tensor<T>,
    bias: Tensor<T>,
    normalized_shape: Vec<usize>,
    eps: f64,
    elementwise_affine: bool,
}
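/*
Sketch (illustrative only, not part of this crate's API): a single-row check
of the gradients `LayerNormBackward` computes in the affine case. With
`g = grad_output * weight`, `N` the row length and `x_hat` the normalized
input, the input gradient is the standard LayerNorm VJP
`(inv_std / N) * (N * g - sum(g) - x_hat * sum(g * x_hat))`; the weight and
bias gradients are `grad_output * x_hat` and `grad_output`. The function names
are hypothetical, and the sketch covers one row only (the layer accumulates
weight and bias gradients over all rows).

```rust
/// Mean and `1 / sqrt(var + eps)` of one row (biased variance).
fn row_stats(x: &[f64], eps: f64) -> (f64, f64) {
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    (mean, 1.0 / (var + eps).sqrt())
}

/// Affine forward for one row: `y = x_hat * w + b`.
fn ln_row(x: &[f64], w: &[f64], b: &[f64], eps: f64) -> Vec<f64> {
    let (mean, inv_std) = row_stats(x, eps);
    x.iter()
        .zip(w.iter().zip(b))
        .map(|(v, (wi, bi))| (v - mean) * inv_std * wi + bi)
        .collect()
}

/// One-row VJP; returns `(grad_x, grad_w, grad_b)` for upstream gradient `go`.
fn ln_row_backward(x: &[f64], go: &[f64], w: &[f64], eps: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let n = x.len() as f64;
    let (mean, inv_std) = row_stats(x, eps);
    let x_hat: Vec<f64> = x.iter().map(|v| (v - mean) * inv_std).collect();
    // g = dL/dx_hat = go * w (affine case).
    let g: Vec<f64> = go.iter().zip(w).map(|(a, b)| a * b).collect();
    let sum_g: f64 = g.iter().sum();
    let sum_g_xhat: f64 = g.iter().zip(&x_hat).map(|(a, b)| a * b).sum();
    let grad_x: Vec<f64> = g
        .iter()
        .zip(&x_hat)
        .map(|(gi, xh)| inv_std / n * (n * gi - sum_g - xh * sum_g_xhat))
        .collect();
    let grad_w: Vec<f64> = go.iter().zip(&x_hat).map(|(a, b)| a * b).collect();
    (grad_x, grad_w, go.to_vec())
}
```

Comparing `ln_row_backward` against a central finite difference of
`sum(grad_output * ln_row(x))` is a quick way to catch sign or scaling
mistakes in either the CPU loop or a GPU kernel.
*/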

impl<T: Float> GradFn<T> for LayerNormBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = self.input.numel() / norm_size;

        // GPU-native fast path for f32/f64 with elementwise affine.
        if self.input.is_cuda() && (is_f32::<T>() || is_f64::<T>()) && self.elementwise_affine {
            if let Some(backend) = gpu_backend() {
                let (gi_h, gw_h, gb_h) = if is_f64::<T>() {
                    backend.layernorm_backward_f64(
                        self.input.gpu_handle()?,
                        grad_output.gpu_handle()?,
                        self.weight.gpu_handle()?,
                        batch_size,
                        norm_size,
                        self.eps,
                    )?
                } else {
                    backend.layernorm_backward_f32(
                        self.input.gpu_handle()?,
                        grad_output.gpu_handle()?,
                        self.weight.gpu_handle()?,
                        batch_size,
                        norm_size,
                        self.eps as f32,
                    )?
                };

                let grad_input_tensor = Tensor::from_storage(
                    TensorStorage::gpu(gi_h),
                    self.input.shape().to_vec(),
                    false,
                )?;

                let grad_weight_out = if self.weight.requires_grad() {
                    Some(Tensor::from_storage(
                        TensorStorage::gpu(gw_h),
                        self.normalized_shape.clone(),
                        false,
                    )?)
                } else {
                    None
                };

                let grad_bias_out = if self.bias.requires_grad() {
                    Some(Tensor::from_storage(
                        TensorStorage::gpu(gb_h),
                        self.normalized_shape.clone(),
                        false,
                    )?)
                } else {
                    None
                };

                return Ok(vec![
                    Some(grad_input_tensor),
                    grad_weight_out,
                    grad_bias_out,
                ]);
            }
        }

        // CPU path. CUDA inputs that did not take the GPU fast path above (no
        // GPU backend registered, a dtype other than f32/f64, or
        // `elementwise_affine == false`) are rejected here.
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "LayerNormBackward",
            });
        }
684        let n_t = T::from(norm_size).unwrap();
685        let eps_t = T::from(self.eps).unwrap();
686
687        let input_data = self.input.data()?;
688        let go_data = grad_output.data()?;
689        let weight_data = self.weight.data()?;
690
691        let mut grad_input = vec![zero::<T>(); self.input.numel()];
692        let mut grad_weight = vec![zero::<T>(); norm_size];
693        let mut grad_bias = vec![zero::<T>(); norm_size];
694
695        for b in 0..batch_size {
696            let start = b * norm_size;
697            let end = start + norm_size;
698            let x_slice = &input_data[start..end];
699            let go_slice = &go_data[start..end];
700
701            // Recompute mean and inv_std.
702            let mean = x_slice.iter().copied().fold(zero::<T>(), |a, x| a + x) / n_t;
703            let var = x_slice.iter().copied().fold(zero::<T>(), |a, x| {
704                let d = x - mean;
705                a + d * d
706            }) / n_t;
707            let inv_std = (var + eps_t).sqrt().recip();
708
709            // dl/dx_hat = go * weight (if affine) or go (if not).
710            // Accumulate sums needed for the VJP.
711            let mut dl_dx_hat_sum = zero::<T>();
712            let mut dl_dx_hat_x_hat_sum = zero::<T>();
713
714            for i in 0..norm_size {
715                let x_hat_i = (x_slice[i] - mean) * inv_std;
716                let dl_dx_hat_i = if self.elementwise_affine {
717                    go_slice[i] * weight_data[i]
718                } else {
719                    go_slice[i]
720                };
721
722                dl_dx_hat_sum += dl_dx_hat_i;
723                dl_dx_hat_x_hat_sum += dl_dx_hat_i * x_hat_i;
724
725                if self.elementwise_affine {
726                    grad_weight[i] += go_slice[i] * x_hat_i;
727                    grad_bias[i] += go_slice[i];
728                }
729            }
730
731            // Compute grad_input for this batch element.
732            let dl_dx_hat_mean = dl_dx_hat_sum / n_t;
733            let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / n_t;
734
735            for i in 0..norm_size {
736                let x_hat_i = (x_slice[i] - mean) * inv_std;
737                let dl_dx_hat_i = if self.elementwise_affine {
738                    go_slice[i] * weight_data[i]
739                } else {
740                    go_slice[i]
741                };
742
743                grad_input[start + i] =
744                    inv_std * (dl_dx_hat_i - dl_dx_hat_mean - x_hat_i * dl_dx_hat_x_hat_mean);
745            }
746        }
747
748        let grad_input_tensor = Tensor::from_storage(
749            TensorStorage::cpu(grad_input),
750            self.input.shape().to_vec(),
751            false,
752        )?;
753
754        let grad_weight_out = if self.elementwise_affine && self.weight.requires_grad() {
755            Some(Tensor::from_storage(
756                TensorStorage::cpu(grad_weight),
757                self.normalized_shape.clone(),
758                false,
759            )?)
760        } else {
761            None
762        };
763
764        let grad_bias_out = if self.elementwise_affine && self.bias.requires_grad() {
765            Some(Tensor::from_storage(
766                TensorStorage::cpu(grad_bias),
767                self.normalized_shape.clone(),
768                false,
769            )?)
770        } else {
771            None
772        };
773
774        Ok(vec![
775            Some(grad_input_tensor),
776            grad_weight_out,
777            grad_bias_out,
778        ])
779    }
780
781    fn inputs(&self) -> Vec<&Tensor<T>> {
782        vec![&self.input, &self.weight, &self.bias]
783    }
784
785    fn name(&self) -> &'static str {
786        "LayerNormBackward"
787    }
788}
789
790// ===========================================================================
791// GroupNorm
792// ===========================================================================
793
794/// Group normalization.
795///
796/// Divides channels into groups and normalizes within each group.
797/// For input of shape `[B, C, ...]`, divides `C` channels into `num_groups`
798/// groups of `C / num_groups` channels each, and normalizes the values
799/// within each group (over channels and spatial dimensions).
800///
801/// Matches `torch.nn.GroupNorm`.
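/**
A minimal standalone sketch of the grouping, assuming a contiguous
`[batch, channels, spatial]` layout (all dims after `C` flattened into
`spatial`, passed as `dims`) and plain `f64` slices instead of this crate's
`Tensor` and `Parameter` types. `group_norm_ref` is a hypothetical helper, not
part of the `ferrotorch-nn` API; it uses the same `b * C * S + c * S + s`
indexing and the biased (divide-by-`N`) group variance as the CPU path of
`forward`.

```rust
// Hypothetical reference helper (not part of ferrotorch-nn): GroupNorm over a
// contiguous `[batch, channels, spatial]` f64 buffer with per-channel affine.
fn group_norm_ref(
    x: &[f64],
    dims: [usize; 3],
    num_groups: usize,
    weight: &[f64],
    bias: &[f64],
    eps: f64,
) -> Vec<f64> {
    let [batch, channels, spatial] = dims;
    assert_eq!(x.len(), batch * channels * spatial);
    assert_eq!(channels % num_groups, 0, "channels must divide into groups");
    let cpg = channels / num_groups; // channels per group
    let group_len = cpg * spatial; // elements per (sample, group)
    let n = group_len as f64;
    let mut out = vec![0.0; x.len()];
    for b in 0..batch {
        for g in 0..num_groups {
            // The channels of one group are contiguous within a sample.
            let start = b * channels * spatial + g * group_len;
            let group = &x[start..start + group_len];
            let mean = group.iter().sum::<f64>() / n;
            let var = group.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
            let inv_std = 1.0 / (var + eps).sqrt();
            for (k, v) in group.iter().enumerate() {
                let c = g * cpg + k / spatial; // channel that element k belongs to
                out[start + k] = (v - mean) * inv_std * weight[c] + bias[c];
            }
        }
    }
    out
}
```

With `weight` all ones and `bias` all zeros, every `(sample, group)` slice of
the result has mean close to 0 and mean square `var / (var + eps)`, which makes
this reference a convenient spot-check for the layer's CPU output.
*/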
#[derive(Debug)]
pub struct GroupNorm<T: Float> {
    /// Number of groups to divide channels into.
    pub num_groups: usize,
    /// Number of channels, i.e. the expected size of the `C` dimension.
    pub num_channels: usize,
    /// Small constant added to the group variance for numerical stability.
    pub eps: f64,
    /// Whether learnable affine parameters are applied. When `false`,
    /// `weight`/`bias` are still allocated (as ones/zeros) but are not
    /// returned by `parameters()`.
    pub affine: bool,
    /// Learnable scale (gamma), shape = `[num_channels]`.
    pub weight: Parameter<T>,
    /// Learnable shift (beta), shape = `[num_channels]`.
    pub bias: Parameter<T>,
    training: bool,
}

impl<T: Float> GroupNorm<T> {
    /// Create a new `GroupNorm` layer.
    ///
    /// # Arguments
    ///
    /// * `num_groups` - Number of groups to divide channels into.
    /// * `num_channels` - Number of channels. Must be divisible by `num_groups`.
    /// * `eps` - Small constant for numerical stability (PyTorch's default is `1e-5`).
    /// * `affine` - Whether to include learnable weight and bias.
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` if `num_groups` or
    /// `num_channels` is zero, or if `num_channels` is not divisible by
    /// `num_groups`.
    pub fn new(
        num_groups: usize,
        num_channels: usize,
        eps: f64,
        affine: bool,
    ) -> FerrotorchResult<Self> {
        if num_groups == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "num_groups must be positive".into(),
            });
        }
        if num_channels == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "num_channels must be positive".into(),
            });
        }
        if num_channels % num_groups != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "num_channels ({num_channels}) must be divisible by num_groups ({num_groups})"
                ),
            });
        }

        let weight = Parameter::ones(&[num_channels])?;
        let bias = Parameter::zeros(&[num_channels])?;

        Ok(Self {
            num_groups,
            num_channels,
            eps,
            affine,
            weight,
            bias,
            training: true,
        })
    }
}

impl<T: Float> Module<T> for GroupNorm<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape().to_vec();
        if shape.len() < 2 {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "GroupNorm: input must have at least 2 dims [B, C, ...], got {:?}",
                    shape
                ),
            });
        }

        let batch_size = shape[0];
        let channels = shape[1];

        if channels != self.num_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "GroupNorm: expected {} channels, got {}",
                    self.num_channels, channels
                ),
            });
        }

        let channels_per_group = channels / self.num_groups;
        // spatial_size = product of the dims after C (1 for a 2-D `[B, C]` input).
        let spatial_size: usize = shape[2..].iter().product();
        let spatial = spatial_size.max(1);
        let group_size = channels_per_group * spatial;

        // GPU fast path: native GroupNorm kernel (#1357). `weight`/`bias`
        // always have length `num_channels` (ones / zeros when `affine` is
        // false), so the kernel's unconditional per-channel affine is the
        // identity in the non-affine case. Mirrors the LayerNorm fast path
        // and `aten/src/ATen/native/cuda/group_norm_kernel.cu`
        // `GroupNormKernelImpl`.
        if input.is_cuda() {
            if let Some(backend) = gpu_backend() {
                let eps_f32 = self.eps as f32;
                let handle = backend.group_norm_f32(
                    input.gpu_handle()?,
                    self.weight.tensor().gpu_handle()?,
                    self.bias.tensor().gpu_handle()?,
                    batch_size,
                    channels,
                    self.num_groups,
                    spatial,
                    eps_f32,
                )?;
                // Only the forward pass runs on the GPU: `GroupNormBackward`
                // has no CUDA path, so backpropagating through this node
                // currently returns `NotImplementedOnCuda`.
                return if is_grad_enabled() && input.requires_grad() {
                    let grad_fn = Arc::new(GroupNormBackward {
                        input: input.clone(),
                        weight: self.weight.tensor().clone(),
                        bias: self.bias.tensor().clone(),
                        num_groups: self.num_groups,
                        num_channels: self.num_channels,
                        eps: self.eps,
                        affine: self.affine,
                    });
                    Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
                } else {
                    Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
                };
            }
            // CUDA input without a registered GPU backend: fail explicitly
            // instead of falling back to the CPU path.
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "GroupNorm::forward",
            });
        }
        let input_data = input.data()?;
        let weight_data = self.weight.tensor().data()?;
        let bias_data = self.bias.tensor().data()?;
        let eps_t = T::from(self.eps).unwrap();
        let group_n = T::from(group_size).unwrap();

        let mut output = vec![zero::<T>(); input.numel()];

        for b in 0..batch_size {
            for g in 0..self.num_groups {
                let c_start = g * channels_per_group;
                let c_end = c_start + channels_per_group;

                // Compute mean over the group.
                let mut sum = zero::<T>();
                for c in c_start..c_end {
                    for s in 0..spatial {
                        let idx = b * channels * spatial + c * spatial + s;
                        sum += input_data[idx];
                    }
                }
                let mean = sum / group_n;

                // Compute variance over the group.
                let mut var_sum = zero::<T>();
                for c in c_start..c_end {
                    for s in 0..spatial {
                        let idx = b * channels * spatial + c * spatial + s;
                        let d = input_data[idx] - mean;
                        var_sum += d * d;
                    }
                }
                let var = var_sum / group_n;
                let inv_std = (var + eps_t).sqrt().recip();

                // Normalize and apply per-channel affine.
                for c in c_start..c_end {
                    for s in 0..spatial {
                        let idx = b * channels * spatial + c * spatial + s;
                        let normed = (input_data[idx] - mean) * inv_std;
                        if self.affine {
                            output[idx] = normed * weight_data[c] + bias_data[c];
                        } else {
                            output[idx] = normed;
                        }
                    }
                }
            }
        }

        // Build the output tensor once; no intermediate copy of `output`.
        let storage = TensorStorage::cpu(output);
        if is_grad_enabled() && input.requires_grad() {
            let grad_fn = Arc::new(GroupNormBackward {
                input: input.clone(),
                weight: self.weight.tensor().clone(),
                bias: self.bias.tensor().clone(),
                num_groups: self.num_groups,
                num_channels: self.num_channels,
                eps: self.eps,
                affine: self.affine,
            });
            Tensor::from_operation(storage, shape, grad_fn)
        } else {
            Tensor::from_storage(storage, shape, false)
        }
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        if self.affine {
            vec![&self.weight, &self.bias]
        } else {
            vec![]
        }
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        if self.affine {
            vec![&mut self.weight, &mut self.bias]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        if self.affine {
            vec![
                ("weight".to_string(), &self.weight),
                ("bias".to_string(), &self.bias),
            ]
        } else {
            vec![]
        }
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }
}

// ---------------------------------------------------------------------------
// GroupNormBackward
// ---------------------------------------------------------------------------

/// Backward node for GroupNorm.
///
/// Inputs stored: `[input, weight, bias]`.
/// Returns: `[grad_input, grad_weight, grad_bias]`.
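/**
The per-group gradient computed in `backward` below is the standard
normalization VJP, `dx = inv_std * (g - mean(g) - x_hat * mean(g * x_hat))`,
where `g` is the gradient with respect to `x_hat` (`go * weight[c]` when
affine, otherwise `go`). A minimal standalone way to sanity-check that formula,
assuming a single group with no affine transform and plain `f64` slices, is to
compare it with a central finite difference of `L = sum(go * y)`. The helpers
`normalize`, `group_vjp`, and `numeric_grad` are hypothetical and not part of
`ferrotorch-nn`.

```rust
// Hypothetical helper (not part of ferrotorch-nn): normalize one group,
// y = (x - mean) / sqrt(var + eps), using the biased variance.
fn normalize(x: &[f64], eps: f64) -> Vec<f64> {
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter().map(|v| (v - mean) * inv_std).collect()
}

// Hypothetical helper: analytic dL/dx of L = sum(go * normalize(x)) for one
// group without affine (so g = go).
fn group_vjp(x: &[f64], go: &[f64], eps: f64) -> Vec<f64> {
    let n = x.len() as f64;
    let x_hat = normalize(x, eps);
    let mean = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    let g_mean = go.iter().sum::<f64>() / n;
    let g_xhat_mean = go.iter().zip(&x_hat).map(|(g, h)| g * h).sum::<f64>() / n;
    x_hat
        .iter()
        .zip(go)
        .map(|(h, g)| inv_std * (g - g_mean - h * g_xhat_mean))
        .collect()
}

// Hypothetical helper: central finite difference of L with respect to x[i].
fn numeric_grad(x: &[f64], go: &[f64], eps: f64, i: usize) -> f64 {
    let h = 1e-6;
    let loss = |xs: &[f64]| -> f64 {
        normalize(xs, eps).iter().zip(go).map(|(y, g)| y * g).sum()
    };
    let mut xp = x.to_vec();
    xp[i] += h;
    let mut xm = x.to_vec();
    xm[i] -= h;
    (loss(&xp[..]) - loss(&xm[..])) / (2.0 * h)
}
```

The entries returned by `group_vjp` also sum to approximately zero, since
shifting a whole group by a constant leaves its normalized output unchanged.
*/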
#[derive(Debug)]
struct GroupNormBackward<T: Float> {
    input: Tensor<T>,
    weight: Tensor<T>,
    bias: Tensor<T>,
    num_groups: usize,
    num_channels: usize,
    eps: f64,
    affine: bool,
}

impl<T: Float> GradFn<T> for GroupNormBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let shape = self.input.shape();
        let batch_size = shape[0];
        let channels = shape[1];
        let channels_per_group = channels / self.num_groups;
        let spatial_size: usize = shape[2..].iter().product();
        let spatial = spatial_size.max(1);
        let group_size = channels_per_group * spatial;
        let group_n = T::from(group_size).unwrap();
        let eps_t = T::from(self.eps).unwrap();

        // There is no GPU backward kernel yet. `GroupNorm::forward` still
        // records this node for CUDA inputs, so backpropagating through a
        // CUDA GroupNorm currently fails here.
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "GroupNormBackward",
            });
        }
        let input_data = self.input.data()?;
        let go_data = grad_output.data()?;
        let weight_data = self.weight.data()?;

        let mut grad_input = vec![zero::<T>(); self.input.numel()];
        let mut grad_weight = vec![zero::<T>(); self.num_channels];
        let mut grad_bias = vec![zero::<T>(); self.num_channels];

        for b in 0..batch_size {
            for g in 0..self.num_groups {
                let c_start = g * channels_per_group;
                let c_end = c_start + channels_per_group;

                // Recompute mean and inv_std for this group.
                let mut sum = zero::<T>();
                for c in c_start..c_end {
                    for s in 0..spatial {
                        let idx = b * channels * spatial + c * spatial + s;
                        sum += input_data[idx];
                    }
                }
                let mean = sum / group_n;

                let mut var_sum = zero::<T>();
                for c in c_start..c_end {
                    for s in 0..spatial {
                        let idx = b * channels * spatial + c * spatial + s;
                        let d = input_data[idx] - mean;
                        var_sum += d * d;
                    }
                }
                let var = var_sum / group_n;
                let inv_std = (var + eps_t).sqrt().recip();

                // Compute sums for the VJP.
                let mut dl_dx_hat_sum = zero::<T>();
                let mut dl_dx_hat_x_hat_sum = zero::<T>();

                for c in c_start..c_end {
                    for s in 0..spatial {
                        let idx = b * channels * spatial + c * spatial + s;
                        let x_hat = (input_data[idx] - mean) * inv_std;
                        let dl_dx_hat = if self.affine {
                            go_data[idx] * weight_data[c]
                        } else {
                            go_data[idx]
                        };
                        dl_dx_hat_sum += dl_dx_hat;
                        dl_dx_hat_x_hat_sum += dl_dx_hat * x_hat;

                        if self.affine {
                            grad_weight[c] += go_data[idx] * x_hat;
                            grad_bias[c] += go_data[idx];
                        }
                    }
                }

                let dl_dx_hat_mean = dl_dx_hat_sum / group_n;
                let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / group_n;

                for (ci, &wd) in weight_data[c_start..c_end].iter().enumerate() {
                    let c = c_start + ci;
                    for s in 0..spatial {
                        let idx = b * channels * spatial + c * spatial + s;
                        let x_hat = (input_data[idx] - mean) * inv_std;
                        let dl_dx_hat = if self.affine {
                            go_data[idx] * wd
                        } else {
                            go_data[idx]
                        };
                        grad_input[idx] =
                            inv_std * (dl_dx_hat - dl_dx_hat_mean - x_hat * dl_dx_hat_x_hat_mean);
                    }
                }
            }
        }

        let grad_input_tensor = Tensor::from_storage(
            TensorStorage::cpu(grad_input),
            self.input.shape().to_vec(),
            false,
        )?;

        let grad_weight_out = if self.affine && self.weight.requires_grad() {
            Some(Tensor::from_storage(
                TensorStorage::cpu(grad_weight),
                vec![self.num_channels],
                false,
            )?)
        } else {
            None
        };

        let grad_bias_out = if self.affine && self.bias.requires_grad() {
            Some(Tensor::from_storage(
                TensorStorage::cpu(grad_bias),
                vec![self.num_channels],
                false,
            )?)
        } else {
            None
        };

        Ok(vec![
            Some(grad_input_tensor),
            grad_weight_out,
            grad_bias_out,
        ])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input, &self.weight, &self.bias]
    }

    fn name(&self) -> &'static str {
        "GroupNormBackward"
    }
}

// ===========================================================================
// RMSNorm
// ===========================================================================

/// Root Mean Square Layer Normalization.
///
/// Applies the transform:
///
/// ```text
/// y = x / sqrt(mean(x^2) + eps) * weight
/// ```
///
/// where the mean is taken over the trailing `normalized_shape` dimensions.
/// Unlike LayerNorm, RMSNorm does not center the input (no mean subtraction)
/// and has no bias parameter, which saves the centering reduction and the
/// bias add. It is used in many modern transformer architectures (LLaMA,
/// Gemma, etc.).
///
/// Matches the RMSNorm formulation from "Root Mean Square Layer Normalization"
/// (Zhang & Sennrich, 2019).
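/**
For a single row, the transform can be sketched as a standalone function on
plain `f64` slices. `rms_norm_row` is a hypothetical helper, not this crate's
API; it only illustrates the formula above. Unlike LayerNorm, the output is not
zero-mean: only its root-mean-square is normalized (to about 1 when `weight` is
all ones and `eps` is small).

```rust
// Hypothetical reference helper (not part of ferrotorch-nn):
// y_i = x_i / sqrt(mean(x^2) + eps) * w_i for one normalized row.
fn rms_norm_row(x: &[f64], weight: &[f64], eps: f64) -> Vec<f64> {
    assert_eq!(x.len(), weight.len(), "weight must match the row length");
    let mean_sq = x.iter().map(|v| v * v).sum::<f64>() / x.len() as f64;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}
```

Because `rms(c * x) = |c| * rms(x)`, the result is (up to `eps`) unchanged when
a row is rescaled by a positive constant, which is a quick property to test.
*/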
#[derive(Debug)]
pub struct RMSNorm<T: Float> {
    /// Shape of the trailing dimensions to normalize over.
    pub normalized_shape: Vec<usize>,
    /// Small constant added to `mean(x^2)` for numerical stability.
    pub eps: f64,
    /// Learnable scale (gamma), shape = `normalized_shape`.
    pub weight: Parameter<T>,
    training: bool,
}

impl<T: Float> RMSNorm<T> {
    /// Create a new `RMSNorm` layer.
    ///
    /// # Arguments
    ///
    /// * `normalized_shape` - The shape of the trailing dimensions to normalize over.
    /// * `eps` - Small constant for numerical stability (e.g. `1e-5`).
    ///
    /// # Errors
    ///
    /// Returns `FerrotorchError::InvalidArgument` if `normalized_shape` is empty.
    pub fn new(normalized_shape: Vec<usize>, eps: f64) -> FerrotorchResult<Self> {
        if normalized_shape.is_empty() {
            return Err(FerrotorchError::InvalidArgument {
                message: "normalized_shape must not be empty".into(),
            });
        }

        let weight = Parameter::ones(&normalized_shape)?;

        Ok(Self {
            normalized_shape,
            eps,
            weight,
            training: true,
        })
    }
}

impl<T: Float> Module<T> for RMSNorm<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape().to_vec();
        let ndim = shape.len();
        let norm_ndim = self.normalized_shape.len();

        if ndim < norm_ndim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "RMSNorm: input has {} dims but normalized_shape has {} dims",
                    ndim, norm_ndim
                ),
            });
        }

        let last_dims = &shape[ndim - norm_ndim..];
        if last_dims != self.normalized_shape.as_slice() {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "RMSNorm: input last dims {:?} don't match normalized_shape {:?}",
                    last_dims, self.normalized_shape
                ),
            });
        }

        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = input.numel() / norm_size;

        // GPU fast path: native RMSNorm kernel.
        if input.is_cuda() {
            if let Some(backend) = ferrotorch_core::gpu_dispatch::gpu_backend() {
                let eps_f32 = self.eps as f32;
                let handle = backend.rmsnorm_f32(
                    input.gpu_handle()?,
                    self.weight.tensor().gpu_handle()?,
                    batch_size,
                    norm_size,
                    eps_f32,
                )?;
                return if is_grad_enabled() && input.requires_grad() {
                    let grad_fn = Arc::new(RMSNormBackward {
                        input: input.clone(),
                        weight: self.weight.tensor().clone(),
                        normalized_shape: self.normalized_shape.clone(),
                        eps: self.eps,
                    });
                    Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
                } else {
                    Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
                };
            }
        }

        // CPU path. CUDA inputs that did not take the GPU fast path above
        // (no registered GPU backend) are rejected here.
        if input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "RMSNorm::forward",
            });
        }
        let input_data = input.data()?;
        let weight_data = self.weight.tensor().data()?;
        let eps_t = T::from(self.eps).unwrap();
        let n_t = T::from(norm_size).unwrap();

        // bf16 has a 7-bit mantissa, so once the running sum of squares is
        // large, each further x^2 term loses most of its bits to rounding.
        // Over hundreds of elements the bf16 mean-of-squares becomes
        // inaccurate and the outputs collapse into near-constant values.
        // Detect bf16 and promote the accumulator (and the eps /
        // normalization) to f32.
        let is_bf16 = std::any::TypeId::of::<T>() == std::any::TypeId::of::<half::bf16>();
        let mut output = Vec::with_capacity(input.numel());

        for b in 0..batch_size {
            let start = b * norm_size;
            let end = start + norm_size;
            let slice = &input_data[start..end];

            if is_bf16 {
                // f32 accumulator path for bf16.
                let eps_f32 = self.eps as f32;
                let n_f32 = norm_size as f32;
                let mut sum_sq = 0.0f32;
                for &x in slice {
                    let xf = x.to_f32().unwrap();
                    sum_sq += xf * xf;
                }
                let inv_rms_f32 = 1.0f32 / ((sum_sq / n_f32) + eps_f32).sqrt();
                let inv_rms = T::from(inv_rms_f32).unwrap();
                for (i, &x) in slice.iter().enumerate() {
                    output.push(x * inv_rms * weight_data[i]);
                }
            } else {
                // rms = sqrt(mean(x^2) + eps)
                let mean_sq = slice.iter().copied().fold(zero::<T>(), |a, x| a + x * x) / n_t;
                let rms = (mean_sq + eps_t).sqrt();
                let inv_rms = rms.recip();

                for (i, &x) in slice.iter().enumerate() {
                    output.push(x * inv_rms * weight_data[i]);
                }
            }
        }

        // Build the output tensor once; no intermediate copy of `output`.
        let storage = TensorStorage::cpu(output);
        if is_grad_enabled() && input.requires_grad() {
            let grad_fn = Arc::new(RMSNormBackward {
                input: input.clone(),
                weight: self.weight.tensor().clone(),
                normalized_shape: self.normalized_shape.clone(),
                eps: self.eps,
            });
            Tensor::from_operation(storage, shape, grad_fn)
        } else {
            Tensor::from_storage(storage, shape, false)
        }
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        vec![&self.weight]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        vec![&mut self.weight]
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        vec![("weight".to_string(), &self.weight)]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }
}

// ---------------------------------------------------------------------------
// RMSNormBackward
// ---------------------------------------------------------------------------

/// Backward node for RMSNorm.
///
/// Forward: `y = x / rms * weight`, where `rms = sqrt(mean(x^2) + eps)` and
/// the mean runs over the `n` elements of one normalized row.
///
/// With `s = 1/rms`, each output is `y_i = x_i * s * w_i`, and:
///
/// ```text
/// ds/dx_j   = -x_j / (n * rms^3)
///
/// dy_i/dx_j = delta_ij * s * w_i + x_i * w_i * ds/dx_j
///           = delta_ij * s * w_i - x_i * w_i * x_j / (n * rms^3)
///
/// grad_x_j  = sum_i go_i * dy_i/dx_j
///           = go_j * s * w_j - (1/(n * rms^3)) * x_j * sum_i(go_i * x_i * w_i)
///           = s * (go_j * w_j - x_j * s^2 * mean(go * x * w))
///
/// grad_w_j  = sum over rows of go_j * x_j * s
/// ```
///
/// Inputs stored: `[input, weight]`.
/// Returns: `[grad_input, grad_weight]`.
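/**
The derivation above can be checked numerically with a standalone sketch for
a single row on plain `f64` slices. `inv_rms`, `rms_forward`, `rms_loss`, and
`rms_backward` are hypothetical helpers (not part of `ferrotorch-nn`);
`rms_backward` evaluates the same `grad_x_j` and `grad_w_j` expressions as the
CPU path of `backward`, so its output can be compared against central finite
differences of `rms_loss`.

```rust
// Hypothetical helper: s = 1 / sqrt(mean(x^2) + eps).
fn inv_rms(x: &[f64], eps: f64) -> f64 {
    let mean_sq = x.iter().map(|v| v * v).sum::<f64>() / x.len() as f64;
    1.0 / (mean_sq + eps).sqrt()
}

// Hypothetical helper: y_i = x_i * s * w_i.
fn rms_forward(x: &[f64], w: &[f64], eps: f64) -> Vec<f64> {
    let s = inv_rms(x, eps);
    x.iter().zip(w).map(|(xi, wi)| xi * s * wi).collect()
}

// Hypothetical helper: scalar loss L = sum_i go_i * y_i for finite differences.
fn rms_loss(x: &[f64], w: &[f64], go: &[f64], eps: f64) -> f64 {
    rms_forward(x, w, eps).iter().zip(go).map(|(y, g)| y * g).sum()
}

// Hypothetical helper: analytic (grad_x, grad_w) for one row.
// grad_x_j = s * (go_j * w_j - x_j * s^2 * mean(go * x * w))
// grad_w_j = go_j * x_j * s
fn rms_backward(x: &[f64], w: &[f64], go: &[f64], eps: f64) -> (Vec<f64>, Vec<f64>) {
    let n = x.len() as f64;
    let s = inv_rms(x, eps);
    let go_x_w_mean = x
        .iter()
        .zip(w)
        .zip(go)
        .map(|((xi, wi), gi)| gi * xi * wi)
        .sum::<f64>()
        / n;
    let grad_x: Vec<f64> = (0..x.len())
        .map(|j| s * (go[j] * w[j] - x[j] * s * s * go_x_w_mean))
        .collect();
    let grad_w: Vec<f64> = (0..x.len()).map(|j| go[j] * x[j] * s).collect();
    (grad_x, grad_w)
}
```

Since `L` is linear in `w`, the `grad_w` comparison is exact up to rounding,
while the `grad_x` comparison exercises the `ds/dx_j` term of the derivation.
*/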
#[derive(Debug)]
struct RMSNormBackward<T: Float> {
    input: Tensor<T>,
    weight: Tensor<T>,
    normalized_shape: Vec<usize>,
    eps: f64,
}

impl<T: Float> GradFn<T> for RMSNormBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let norm_size: usize = self.normalized_shape.iter().product();
        let batch_size = self.input.numel() / norm_size;

        // GPU-native fast path for f32/f64
        if self.input.is_cuda() && (is_f32::<T>() || is_f64::<T>()) {
            if let Some(backend) = gpu_backend() {
                let (gi_h, gw_h) = if is_f64::<T>() {
                    backend.rmsnorm_backward_f64(
                        self.input.gpu_handle()?,
                        grad_output.gpu_handle()?,
                        self.weight.gpu_handle()?,
                        batch_size,
                        norm_size,
                        self.eps,
                    )?
                } else {
                    backend.rmsnorm_backward_f32(
                        self.input.gpu_handle()?,
                        grad_output.gpu_handle()?,
                        self.weight.gpu_handle()?,
                        batch_size,
                        norm_size,
                        self.eps as f32,
                    )?
                };

                let grad_input_tensor = Tensor::from_storage(
                    TensorStorage::gpu(gi_h),
                    self.input.shape().to_vec(),
                    false,
                )?;

                let grad_weight_out = if self.weight.requires_grad() {
                    Some(Tensor::from_storage(
                        TensorStorage::gpu(gw_h),
                        self.normalized_shape.clone(),
                        false,
                    )?)
                } else {
                    None
                };

                return Ok(vec![Some(grad_input_tensor), grad_weight_out]);
            }
        }

        // CPU path. CUDA inputs that did not take the GPU fast path above
        // (no registered GPU backend, or a dtype other than f32/f64) are
        // rejected here.
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "RMSNormBackward",
            });
        }
        let n_t = T::from(norm_size).unwrap();
        let eps_t = T::from(self.eps).unwrap();

        let input_data = self.input.data()?;
        let go_data = grad_output.data()?;
        let weight_data = self.weight.data()?;

        let mut grad_input = vec![zero::<T>(); self.input.numel()];
        let mut grad_weight = vec![zero::<T>(); norm_size];

        for b in 0..batch_size {
            let start = b * norm_size;
            let end = start + norm_size;
            let x_slice = &input_data[start..end];
            let go_slice = &go_data[start..end];

            // Recompute rms.
            let mean_sq = x_slice.iter().copied().fold(zero::<T>(), |a, x| a + x * x) / n_t;
            let rms = (mean_sq + eps_t).sqrt();
            let inv_rms = rms.recip();
            let inv_rms_sq = inv_rms * inv_rms;

            // sum_i(go_i * x_i * w_i) / n
            let go_x_w_mean = x_slice
                .iter()
                .zip(go_slice.iter())
                .zip(weight_data.iter())
                .fold(zero::<T>(), |a, ((&x, &go), &w)| a + go * x * w)
                / n_t;

            for i in 0..norm_size {
                // grad_x_j = inv_rms * (go_j * w_j - x_j * inv_rms^2 * go_x_w_mean)
                grad_input[start + i] = inv_rms
                    * (go_slice[i] * weight_data[i] - x_slice[i] * inv_rms_sq * go_x_w_mean);

                // grad_weight_i += go_i * x_i * inv_rms
                grad_weight[i] += go_slice[i] * x_slice[i] * inv_rms;
            }
        }

        let grad_input_tensor = Tensor::from_storage(
            TensorStorage::cpu(grad_input),
            self.input.shape().to_vec(),
            false,
        )?;

        let grad_weight_out = if self.weight.requires_grad() {
            Some(Tensor::from_storage(
                TensorStorage::cpu(grad_weight),
                self.normalized_shape.clone(),
                false,
            )?)
        } else {
            None
        };

        Ok(vec![Some(grad_input_tensor), grad_weight_out])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input, &self.weight]
    }

    fn name(&self) -> &'static str {
        "RMSNormBackward"
    }
}

// ===========================================================================
// BatchNorm2d
// ===========================================================================

/// Batch normalization over 4D inputs (a mini-batch of 2D inputs with an
/// additional channel dimension).
///
/// Applies the transform per channel:
///
/// ```text
/// y = (x - mean) / sqrt(var + eps) * weight + bias
/// ```
///
/// During **training**, `mean` and `var` are computed from the current
1568/// mini-batch over the `(B, H, W)` dimensions, and exponential moving
1569/// averages of these statistics are maintained in `running_mean` and
1570/// `running_var`.
1571///
1572/// During **evaluation**, the accumulated `running_mean` and `running_var`
1573/// are used instead of batch statistics.
1574///
1575/// Matches `torch.nn.BatchNorm2d`.
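///
/// # Example (standalone sketch)
///
/// The following is a minimal sketch of the training-mode statistics and
/// normalization. It works on plain `f64` slices in row-major `[B, C, H, W]`
/// order instead of this crate's `Tensor` type. `batch_norm2d_train` is a
/// hypothetical helper and is not part of this crate's API. It uses the same
/// indexing and biased variance as the CPU path of `forward` below, and it
/// leaves out both the affine transform and the running-statistics update.
///
/// ```rust
/// // Hypothetical helper: returns (output, per-channel mean, per-channel biased var).
/// fn batch_norm2d_train(
///     x: &[f64],
///     batch: usize,
///     channels: usize,
///     spatial: usize, // H * W
///     eps: f64,
/// ) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
///     let count = (batch * spatial) as f64;
///     let mut mean = vec![0.0; channels];
///     let mut var = vec![0.0; channels];
///     for c in 0..channels {
///         // Channel `c` occupies `spatial` contiguous values in every sample.
///         let mut sum = 0.0;
///         for b in 0..batch {
///             let base = b * channels * spatial + c * spatial;
///             sum += x[base..base + spatial].iter().sum::<f64>();
///         }
///         mean[c] = sum / count;
///         let mut sq = 0.0;
///         for b in 0..batch {
///             let base = b * channels * spatial + c * spatial;
///             sq += x[base..base + spatial]
///                 .iter()
///                 .map(|v| (v - mean[c]).powi(2))
///                 .sum::<f64>();
///         }
///         var[c] = sq / count; // biased: divide by N, not N - 1
///     }
///     let mut out = vec![0.0; x.len()];
///     for (idx, y) in out.iter_mut().enumerate() {
///         let c = (idx / spatial) % channels; // channel of a flat NCHW offset
///         *y = (x[idx] - mean[c]) / (var[c] + eps).sqrt();
///     }
///     (out, mean, var)
/// }
/// ```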
pub struct BatchNorm2d<T: Float> {
    /// Number of channels (features) `C`.
    pub num_features: usize,
    /// Small constant for numerical stability.
    pub eps: f64,
    /// Momentum for the running mean / variance update:
    /// `running = (1 - momentum) * running + momentum * batch`, where `batch`
    /// is the batch mean or, for the variance, the unbiased (Bessel-corrected)
    /// batch variance.
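    ///
    /// Below is a standalone sketch of one training-step update for a single
    /// channel. It mirrors the CPU path of `forward`: the running mean moves
    /// toward the batch mean, and the running variance moves toward the biased
    /// batch variance rescaled by Bessel's correction `count / (count - 1)`,
    /// where `count = B * H * W`. `update_running` is a hypothetical helper
    /// and is not part of this crate.
    ///
    /// ```rust
    /// // Hypothetical helper: returns the updated (running_mean, running_var).
    /// fn update_running(
    ///     running_mean: f64,
    ///     running_var: f64,
    ///     batch_mean: f64,
    ///     biased_batch_var: f64,
    ///     count: usize,
    ///     momentum: f64,
    /// ) -> (f64, f64) {
    ///     // Same guard as `forward`: no correction when count <= 1.
    ///     let bessel = if count > 1 {
    ///         count as f64 / (count as f64 - 1.0)
    ///     } else {
    ///         1.0
    ///     };
    ///     let rm = (1.0 - momentum) * running_mean + momentum * batch_mean;
    ///     let rv = (1.0 - momentum) * running_var + momentum * biased_batch_var * bessel;
    ///     (rm, rv)
    /// }
    /// ```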
    pub momentum: f64,
    /// Whether to apply a learnable affine transform.
    pub affine: bool,
    /// Learnable scale (gamma), shape `[C]`. `None` when `affine == false`.
    pub weight: Option<Parameter<T>>,
    /// Learnable shift (beta), shape `[C]`. `None` when `affine == false`.
    pub bias: Option<Parameter<T>>,
    /// Exponential moving average of per-channel means.
    /// Uses `Mutex` for interior mutability because `Module::forward` takes `&self`
    /// and `Module` requires `Send + Sync`.
    running_mean: Mutex<Vec<f64>>,
    /// Exponential moving average of per-channel variances.
    running_var: Mutex<Vec<f64>>,
    /// Number of training-mode forward calls (incremented once per batch).
    num_batches_tracked: Mutex<usize>,
    /// Whether the layer is in training mode.
    training: Mutex<bool>,
}

// Manual `Debug` impl: reports the configuration, parameters, and training
// flag, and omits the running-statistics buffers.
impl<T: Float> std::fmt::Debug for BatchNorm2d<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BatchNorm2d")
            .field("num_features", &self.num_features)
            .field("eps", &self.eps)
            .field("momentum", &self.momentum)
            .field("affine", &self.affine)
            .field("weight", &self.weight)
            .field("bias", &self.bias)
            .field("training", &self.training)
            .finish()
    }
}

impl<T: Float> BatchNorm2d<T> {
    /// Create a new `BatchNorm2d` layer.
    ///
    /// # Arguments
    ///
    /// * `num_features` - Number of channels `C`.
    /// * `eps` - Numerical stability constant (PyTorch default: `1e-5`).
    /// * `momentum` - Running-statistics momentum (PyTorch default: `0.1`).
    /// * `affine` - Whether to include learnable weight and bias.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] if `num_features == 0`.
    pub fn new(
        num_features: usize,
        eps: f64,
        momentum: f64,
        affine: bool,
    ) -> FerrotorchResult<Self> {
        if num_features == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "num_features must be positive".into(),
            });
        }

        let weight = if affine {
            Some(Parameter::ones(&[num_features])?)
        } else {
            None
        };

        let bias = if affine {
            Some(Parameter::zeros(&[num_features])?)
        } else {
            None
        };

        Ok(Self {
            num_features,
            eps,
            momentum,
            affine,
            weight,
            bias,
            running_mean: Mutex::new(vec![0.0; num_features]),
            running_var: Mutex::new(vec![1.0; num_features]),
            num_batches_tracked: Mutex::new(0),
            training: Mutex::new(true),
        })
    }

    /// Access the current running mean (snapshot copy).
    pub fn running_mean(&self) -> Vec<f64> {
        self.running_mean.lock().unwrap().clone()
    }

    /// Access the current running variance (snapshot copy).
    pub fn running_var(&self) -> Vec<f64> {
        self.running_var.lock().unwrap().clone()
    }

    /// Number of training batches tracked so far.
    pub fn num_batches_tracked(&self) -> usize {
        *self.num_batches_tracked.lock().unwrap()
    }

    /// Set the running mean from a slice of length [`num_features`].
    ///
    /// Used to load `running_mean` from a state dict (#984). Validates:
    ///
    /// 1. `value.len() == num_features` (else [`FerrotorchError::ShapeMismatch`]).
    /// 2. Every entry is finite (else [`FerrotorchError::InvalidArgument`]).
    ///
    /// Storage is `Mutex<Vec<f64>>` for numerical stability across BN's
    /// running-stat update; entries are widened from `T` to `f64` here
    /// and narrowed back to `T` when the eval-mode forward path reads
    /// them. The setter acquires the same `Mutex` as the forward path, so
    /// a concurrent forward call and setter call remain safe.
    ///
    /// Round-trip: after `set_running_mean(&v)`,
    /// [`running_mean`](Self::running_mean) returns a `Vec<f64>` whose
    /// elements equal `v[i]` widened to `f64`.
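    ///
    /// For `T = f32`, the widening is exact but carries the `f32` rounding of
    /// the input into the stored value. As a result, the stored value
    /// generally differs from the `f64` nearest to the same decimal literal.
    /// Below is a standalone sketch using only `std`. `widen` is a
    /// hypothetical stand-in for the setter's `to_f64` conversion.
    ///
    /// ```rust
    /// // Hypothetical stand-in for the `T -> f64` widening done by the setter.
    /// fn widen(values: &[f32]) -> Vec<f64> {
    ///     values.iter().map(|&x| f64::from(x)).collect()
    /// }
    /// ```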
    ///
    /// [`num_features`]: Self::num_features
    pub fn set_running_mean(&self, value: &[T]) -> FerrotorchResult<()> {
        if value.len() != self.num_features {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BatchNorm2d::set_running_mean: expected slice of length \
                     num_features={}, got {}",
                    self.num_features,
                    value.len()
                ),
            });
        }
        for (i, x) in value.iter().enumerate() {
            if !num_traits::Float::is_finite(*x) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "BatchNorm2d::set_running_mean: non-finite value at \
                         index {i} (running_mean must be finite)"
                    ),
                });
            }
        }
        let mut rm = self.running_mean.lock().unwrap();
        for (slot, x) in rm.iter_mut().zip(value.iter()) {
            *slot = x.to_f64().unwrap();
        }
        Ok(())
    }

    /// Set the running variance from a slice of length [`num_features`].
    ///
    /// Used to load `running_var` from a state dict (#984). Validates:
    ///
    /// 1. `value.len() == num_features` (else [`FerrotorchError::ShapeMismatch`]).
    /// 2. Every entry is finite (else [`FerrotorchError::InvalidArgument`]).
    /// 3. Every entry is non-negative (else [`FerrotorchError::InvalidArgument`]),
    ///    since a variance is by definition `>= 0`. An entry below `-eps`
    ///    makes `sqrt(var + eps)` `NaN` in the eval-mode forward path and
    ///    silently corrupts downstream activations; rejecting it at the API
    ///    boundary surfaces the bug instead.
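    ///
    /// Below is a standalone sketch, in plain `f64` arithmetic, of the failure
    /// mode that check 3 guards against. `inv_std` is a hypothetical helper
    /// that computes the eval-mode scale `1 / sqrt(running_var + eps)`.
    /// Negative entries between `-eps` and zero still produce a finite scale,
    /// so waiting for a `NaN` to appear would not catch every invalid entry.
    ///
    /// ```rust
    /// // Hypothetical helper: the per-channel eval-mode scale factor.
    /// fn inv_std(running_var: f64, eps: f64) -> f64 {
    ///     (running_var + eps).sqrt().recip()
    /// }
    /// ```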
    ///
    /// Same `Mutex<Vec<f64>>` storage and `T → f64` widening as
    /// [`set_running_mean`](Self::set_running_mean).
    ///
    /// [`num_features`]: Self::num_features
    pub fn set_running_var(&self, value: &[T]) -> FerrotorchResult<()> {
        if value.len() != self.num_features {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BatchNorm2d::set_running_var: expected slice of length \
                     num_features={}, got {}",
                    self.num_features,
                    value.len()
                ),
            });
        }
        let zero_t = zero::<T>();
        for (i, x) in value.iter().enumerate() {
            if !num_traits::Float::is_finite(*x) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "BatchNorm2d::set_running_var: non-finite value at \
                         index {i} (running_var must be finite)"
                    ),
                });
            }
            if *x < zero_t {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "BatchNorm2d::set_running_var: negative value {} at \
                         index {i} (running_var must be non-negative)",
                        x.to_f64().unwrap()
                    ),
                });
            }
        }
        let mut rv = self.running_var.lock().unwrap();
        for (slot, x) in rv.iter_mut().zip(value.iter()) {
            *slot = x.to_f64().unwrap();
        }
        Ok(())
    }

    /// Set the number of training batches tracked (#984).
    ///
    /// Used to load `num_batches_tracked` from a state dict. The value
    /// is a non-negative integer in PyTorch's reference impl; this
    /// setter accepts any `usize` and writes through the same `Mutex`
    /// the forward path uses. It currently always returns `Ok(())`.
    pub fn set_num_batches_tracked(&self, value: usize) -> FerrotorchResult<()> {
        let mut nbt = self.num_batches_tracked.lock().unwrap();
        *nbt = value;
        Ok(())
    }
}

impl<T: Float> Module<T> for BatchNorm2d<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape().to_vec();
        if shape.len() != 4 {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BatchNorm2d: expected 4D input [B, C, H, W], got {:?}",
                    shape
                ),
            });
        }

        let batch = shape[0];
        let channels = shape[1];
        let height = shape[2];
        let width = shape[3];
        let spatial = height * width;

        if channels != self.num_features {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BatchNorm2d: expected {} channels, got {}",
                    self.num_features, channels
                ),
            });
        }

        // torch `_verify_batch_size` guard (`torch/nn/functional.py:2811-2814`):
        // in training mode the per-channel element count (numel / channels =
        // batch * spatial) must exceed 1, since the variance of a single sample
        // is undefined. Fires BEFORE the CPU/GPU dispatch so both paths reject
        // identically (#1558).
        if *self.training.lock().unwrap() && batch * spatial <= 1 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Expected more than 1 value per channel when training, got input size {:?}",
                    shape
                ),
            });
        }

        // GPU fast path (#1449): per-channel normalize over (B, H, W). f32-only
        // (the kernel is f32); running-stat update mirrors the CPU branch.
        if input.is_cuda() {
            if is_f32::<T>() {
                let is_training = *self.training.lock().unwrap();
                if let Some(handle) = batch_norm_gpu_forward(
                    input,
                    self.weight.as_ref().map(|w| w.tensor()),
                    self.bias.as_ref().map(|b| b.tensor()),
                    &self.running_mean,
                    &self.running_var,
                    &self.num_batches_tracked,
                    self.momentum,
                    self.eps,
                    channels,
                    spatial,
                    is_training,
                )? {
                    return if is_grad_enabled() && input.requires_grad() {
                        let grad_fn = Arc::new(BatchNorm2dBackward {
                            input: input.clone(),
                            // Empty placeholder: the GPU backward recomputes
                            // x_hat from `input`.
                            x_hat: Tensor::from_storage(
                                TensorStorage::cpu(Vec::new()),
                                vec![0],
                                false,
                            )?,
                            weight: self.weight.as_ref().map(|w| w.tensor().clone()),
                            bias: self.bias.as_ref().map(|b| b.tensor().clone()),
                            chan_var: Vec::new(),
                            eps: self.eps,
                            affine: self.affine,
                            is_training,
                            running_mean: self.running_mean.lock().unwrap().clone(),
                            running_var: self.running_var.lock().unwrap().clone(),
                        });
                        Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
                    } else {
                        Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
                    };
                }
            }
            // Non-f32 CUDA input, or the GPU fast path returned `None`.
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "BatchNorm2d::forward",
            });
        }
        let input_data = input.data()?;
        let eps_t = T::from(self.eps).unwrap();

        // Propagate storage errors instead of panicking on them.
        let weight_data = self.weight.as_ref().map(|w| w.tensor().data()).transpose()?;
        let bias_data = self.bias.as_ref().map(|b| b.tensor().data()).transpose()?;

        let is_training = *self.training.lock().unwrap();

        // Per-channel mean and variance (as T for computation).
        let mut chan_mean = vec![zero::<T>(); channels];
        let mut chan_var = vec![zero::<T>(); channels];

        if is_training {
            // Compute batch statistics over (B, H, W).
            let count = batch * spatial;
            let count_t = T::from(count).unwrap();

            for c in 0..channels {
                let mut sum = zero::<T>();
                for b in 0..batch {
                    let base = b * channels * spatial + c * spatial;
                    for s in 0..spatial {
                        sum += input_data[base + s];
                    }
                }
                chan_mean[c] = sum / count_t;

                let mut var_sum = zero::<T>();
                for b in 0..batch {
                    let base = b * channels * spatial + c * spatial;
                    for s in 0..spatial {
                        let d = input_data[base + s] - chan_mean[c];
                        var_sum += d * d;
                    }
                }
                // Biased variance for normalization (as in PyTorch); the
                // running update below applies Bessel's correction.
                chan_var[c] = var_sum / count_t;
            }

            // Update running statistics.
            {
                let mut rm = self.running_mean.lock().unwrap();
                let mut rv = self.running_var.lock().unwrap();
                let mut nbt = self.num_batches_tracked.lock().unwrap();
                *nbt += 1;

                let mom = self.momentum;
                // For running_var, PyTorch uses unbiased (Bessel-corrected)
                // variance in the running update.
                let bessel = if count > 1 {
                    count as f64 / (count as f64 - 1.0)
                } else {
                    1.0
                };

                for c in 0..channels {
                    let batch_mean_f64 = chan_mean[c].to_f64().unwrap();
                    let batch_var_f64 = chan_var[c].to_f64().unwrap();

                    rm[c] = (1.0 - mom) * rm[c] + mom * batch_mean_f64;
                    rv[c] = (1.0 - mom) * rv[c] + mom * batch_var_f64 * bessel;
                }
            }
        } else {
            // Eval mode: use running statistics.
            let rm = self.running_mean.lock().unwrap();
            let rv = self.running_var.lock().unwrap();

            for c in 0..channels {
                chan_mean[c] = T::from(rm[c]).unwrap();
                chan_var[c] = T::from(rv[c]).unwrap();
            }
        }

        // Normalize and optionally scale/shift.
        let mut output = vec![zero::<T>(); input.numel()];

        // Pre-compute inv_std per channel.
        let inv_std: Vec<T> = chan_var.iter().map(|&v| (v + eps_t).sqrt().recip()).collect();

        // Keep x_hat for the backward pass only when a graph will be recorded.
        let need_x_hat = is_grad_enabled() && input.requires_grad();
        let mut x_hat_data = if need_x_hat {
            Vec::with_capacity(input.numel())
        } else {
            Vec::new()
        };

        for b in 0..batch {
            for c in 0..channels {
                let base = b * channels * spatial + c * spatial;
                for s in 0..spatial {
                    let idx = base + s;
                    let normed = (input_data[idx] - chan_mean[c]) * inv_std[c];

                    if need_x_hat {
                        x_hat_data.push(normed);
                    }

                    if self.affine {
                        let w = weight_data.as_ref().unwrap();
                        let bi = bias_data.as_ref().unwrap();
                        output[idx] = normed * w[c] + bi[c];
                    } else {
                        output[idx] = normed;
                    }
                }
            }
        }

        if need_x_hat {
            let weight_tensor = self.weight.as_ref().map(|w| w.tensor().clone());
            let bias_tensor = self.bias.as_ref().map(|b| b.tensor().clone());

            let grad_fn = Arc::new(BatchNorm2dBackward {
                input: input.clone(),
                x_hat: Tensor::from_storage(TensorStorage::cpu(x_hat_data), shape.clone(), false)?,
                weight: weight_tensor,
                bias: bias_tensor,
                chan_var: chan_var.iter().map(|v| v.to_f64().unwrap()).collect(),
                eps: self.eps,
                affine: self.affine,
                is_training,
                running_mean: self.running_mean.lock().unwrap().clone(),
                running_var: self.running_var.lock().unwrap().clone(),
            });

            // Hand `output` straight to the graph node; no intermediate copy.
            Tensor::from_operation(TensorStorage::cpu(output), shape, grad_fn)
        } else {
            Tensor::from_storage(TensorStorage::cpu(output), shape, false)
        }
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        match (&self.weight, &self.bias) {
            (Some(w), Some(b)) => vec![w, b],
            _ => vec![],
        }
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        match (&mut self.weight, &mut self.bias) {
            (Some(w), Some(b)) => vec![w, b],
            _ => vec![],
        }
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        match (&self.weight, &self.bias) {
            (Some(w), Some(b)) => vec![("weight".to_string(), w), ("bias".to_string(), b)],
            _ => vec![],
        }
    }

    fn train(&mut self) {
        *self.training.lock().unwrap() = true;
    }

    fn eval(&mut self) {
        *self.training.lock().unwrap() = false;
    }

    fn is_training(&self) -> bool {
        *self.training.lock().unwrap()
    }

    /// Downcast hook for state-dict loaders that need to populate
    /// BN's running mean / variance / `num_batches_tracked` (#984).
    /// Returning `Some(self)` lets a generic loader walking
    /// [`Module::named_modules`] downcast a `&dyn Module<T>` to
    /// `&BatchNorm2d<T>` and call the typed setters.
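    ///
    /// Below is a standalone sketch of the downcast that this hook enables.
    /// `Layer`, `Bn`, `Relu`, and `running_mean_of` are hypothetical stand-ins
    /// for `Module<T>`, `BatchNorm2d<T>`, some other module, and a state-dict
    /// loader. Only `std::any::Any::downcast_ref` is real API.
    ///
    /// ```rust
    /// use std::any::Any;
    ///
    /// trait Layer {
    ///     // Hypothetical default: a layer opts out of downcasting.
    ///     fn as_any(&self) -> Option<&dyn Any> {
    ///         None
    ///     }
    /// }
    ///
    /// struct Bn {
    ///     running_mean: Vec<f64>,
    /// }
    ///
    /// struct Relu;
    ///
    /// impl Layer for Bn {
    ///     fn as_any(&self) -> Option<&dyn Any> {
    ///         Some(self)
    ///     }
    /// }
    ///
    /// impl Layer for Relu {}
    ///
    /// // Returns the running mean if `layer` is a `Bn`, else `None`.
    /// fn running_mean_of(layer: &dyn Layer) -> Option<Vec<f64>> {
    ///     layer
    ///         .as_any()?
    ///         .downcast_ref::<Bn>()
    ///         .map(|bn| bn.running_mean.clone())
    /// }
    /// ```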
    fn as_any(&self) -> Option<&dyn std::any::Any> {
        Some(self)
    }
}

// ---------------------------------------------------------------------------
// BatchNorm2dBackward
// ---------------------------------------------------------------------------

/// Backward node for `BatchNorm2d`.
///
/// Given the forward:
///
/// ```text
/// x_hat = (x - mean) / sqrt(var + eps)
/// y = weight * x_hat + bias          (if affine)
/// ```
///
/// The gradients are:
///
/// - `grad_bias[c]  = sum(grad_output[:, c, :, :])` over `(B, H, W)`
/// - `grad_weight[c] = sum(grad_output[:, c, :, :] * x_hat[:, c, :, :])` over `(B, H, W)`
/// - `grad_input` in training mode, where `mean` and `var` are batch
///   statistics and therefore depend on `x`:
///   ```text
///   dl_dx_hat = grad_output * weight              (if affine, else grad_output)
///   grad_input = (1 / sqrt(var + eps)) *
///       (dl_dx_hat - mean(dl_dx_hat) - x_hat * mean(dl_dx_hat * x_hat))
///   ```
///   where the means are taken over `(B, H, W)`. In eval mode, `mean` and
///   `var` are the running statistics, which are constant with respect to
///   `x`, so `grad_input = dl_dx_hat / sqrt(var + eps)`.
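///
/// Below is a standalone sketch of the training-mode `grad_input` formula
/// for one channel with `affine == false`, so that `dl_dx_hat ==
/// grad_output`. The input is a flat `f64` slice holding that channel's
/// `(B, H, W)` values. `stats`, `normalize`, and `bn_grad_input` are
/// hypothetical helpers and are not part of this crate. A central finite
/// difference of `sum(grad_output * x_hat)` should agree with
/// `bn_grad_input` to within rounding.
///
/// ```rust
/// // Hypothetical helper: (batch mean, 1 / sqrt(biased var + eps)).
/// fn stats(x: &[f64], eps: f64) -> (f64, f64) {
///     let n = x.len() as f64;
///     let mean = x.iter().sum::<f64>() / n;
///     let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
///     (mean, 1.0 / (var + eps).sqrt())
/// }
///
/// // Forward: x_hat = (x - mean) * inv_std.
/// fn normalize(x: &[f64], eps: f64) -> Vec<f64> {
///     let (mean, inv_std) = stats(x, eps);
///     x.iter().map(|v| (v - mean) * inv_std).collect()
/// }
///
/// // Backward: inv_std * (g - mean(g) - x_hat * mean(g * x_hat)).
/// fn bn_grad_input(x: &[f64], grad_out: &[f64], eps: f64) -> Vec<f64> {
///     let n = x.len() as f64;
///     let (_, inv_std) = stats(x, eps);
///     let x_hat = normalize(x, eps);
///     let g_mean = grad_out.iter().sum::<f64>() / n;
///     let gx_mean = grad_out.iter().zip(&x_hat).map(|(g, xh)| g * xh).sum::<f64>() / n;
///     grad_out
///         .iter()
///         .zip(&x_hat)
///         .map(|(g, xh)| inv_std * (g - g_mean - xh * gx_mean))
///         .collect()
/// }
/// ```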
2086///
2087/// Inputs stored: `[input, weight?, bias?]`.
2088/// Returns: `[grad_input, grad_weight?, grad_bias?]`.
2089#[derive(Debug)]
2090struct BatchNorm2dBackward<T: Float> {
2091    input: Tensor<T>,
2092    /// Pre-computed normalized values `(x - mean) / sqrt(var + eps)`.
2093    /// Empty on the GPU path (the GPU backward kernel recomputes from `input`).
2094    x_hat: Tensor<T>,
2095    weight: Option<Tensor<T>>,
2096    bias: Option<Tensor<T>>,
2097    /// Per-channel batch variances (biased). Empty on the GPU path.
2098    chan_var: Vec<f64>,
2099    eps: f64,
2100    affine: bool,
2101    /// GPU backward (#1449): forward training-mode flag (kernel recomputes
2102    /// batch stats) vs. eval (uses the running snapshots).
2103    is_training: bool,
2104    /// Running-mean snapshot for the GPU eval-mode backward (`[channels]`).
2105    running_mean: Vec<f64>,
2106    /// Running-var snapshot for the GPU eval-mode backward (`[channels]`).
2107    running_var: Vec<f64>,
2108}
2109
2110impl<T: Float> GradFn<T> for BatchNorm2dBackward<T> {
2111    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2112        let shape = self.input.shape();
2113        let batch = shape[0];
2114        let channels = shape[1];
2115        let height = shape[2];
2116        let width = shape[3];
2117        let spatial = height * width;
2118        let count = batch * spatial;
2119        let count_t = T::from(count).unwrap();
2120
2121        // GPU-native backward (#1449): compute grad_input/weight/bias on-device
2122        // — NO `.cpu()` round trip (R-CODE-4). Mirrors the LayerNormBackward GPU
2123        // fast path above.
2124        if self.input.is_cuda() && is_f32::<T>() {
2125            let weight_dev;
2126            let weight_buf = match self.weight.as_ref() {
2127                Some(w) => w,
2128                None => {
2129                    weight_dev = ferrotorch_core::creation::ones::<T>(&[channels])?
2130                        .to(self.input.device())?;
2131                    &weight_dev
2132                }
2133            };
2134            if let Some(grads) = batch_norm_gpu_backward(
2135                &self.input,
2136                grad_output,
2137                weight_buf,
2138                &self.running_mean,
2139                &self.running_var,
2140                batch,
2141                channels,
2142                spatial,
2143                self.eps,
2144                self.is_training,
2145                self.affine,
2146                self.weight.as_ref().is_some_and(|w| w.requires_grad()),
2147                self.bias.as_ref().is_some_and(|b| b.requires_grad()),
2148            )? {
2149                return Ok(grads);
2150            }
2151            return Err(FerrotorchError::NotImplementedOnCuda {
2152                op: "BatchNorm2dBackward",
2153            });
2154        }
2155        if self.input.is_cuda() {
2156            return Err(FerrotorchError::NotImplementedOnCuda {
2157                op: "BatchNorm2dBackward",
2158            });
2159        }
2160        let go_data = grad_output.data()?;
2161        let x_hat_data = self.x_hat.data()?;
2162
2163        let weight_data = self.weight.as_ref().map(|w| w.data().unwrap().to_vec());
2164
2165        let mut grad_input = vec![zero::<T>(); self.input.numel()];
2166        let mut grad_weight = vec![zero::<T>(); channels];
2167        let mut grad_bias = vec![zero::<T>(); channels];
2168
2169        for c in 0..channels {
2170            let var_f64 = self.chan_var[c];
2171            let inv_std = T::from(1.0 / (var_f64 + self.eps).sqrt()).unwrap();
2172
2173            // First pass: accumulate sums for the VJP.
2174            let mut dl_dx_hat_sum = zero::<T>();
2175            let mut dl_dx_hat_x_hat_sum = zero::<T>();
2176
2177            for b in 0..batch {
2178                let base = b * channels * spatial + c * spatial;
2179                for s in 0..spatial {
2180                    let idx = base + s;
2181                    let x_h = x_hat_data[idx];
2182                    let go = go_data[idx];
2183
2184                    let dl_dx_hat = if self.affine {
2185                        go * weight_data.as_ref().unwrap()[c]
2186                    } else {
2187                        go
2188                    };
2189
2190                    dl_dx_hat_sum += dl_dx_hat;
2191                    dl_dx_hat_x_hat_sum += dl_dx_hat * x_h;
2192
2193                    if self.affine {
2194                        grad_weight[c] += go * x_h;
2195                        grad_bias[c] += go;
2196                    }
2197                }
2198            }
2199
2200            let dl_dx_hat_mean = dl_dx_hat_sum / count_t;
2201            let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / count_t;
2202
2203            // Second pass: compute grad_input.
2204            for b in 0..batch {
2205                let base = b * channels * spatial + c * spatial;
2206                for s in 0..spatial {
2207                    let idx = base + s;
2208                    let x_h = x_hat_data[idx];
2209                    let go = go_data[idx];
2210
2211                    let dl_dx_hat = if self.affine {
2212                        go * weight_data.as_ref().unwrap()[c]
2213                    } else {
2214                        go
2215                    };
2216
2217                    grad_input[idx] =
2218                        inv_std * (dl_dx_hat - dl_dx_hat_mean - x_h * dl_dx_hat_x_hat_mean);
2219                }
2220            }
2221        }
2222
2223        let grad_input_tensor = Tensor::from_storage(
2224            TensorStorage::cpu(grad_input),
2225            self.input.shape().to_vec(),
2226            false,
2227        )?;
2228
2229        let grad_weight_out = if self.affine {
2230            if let Some(ref w) = self.weight {
2231                if w.requires_grad() {
2232                    Some(Tensor::from_storage(
2233                        TensorStorage::cpu(grad_weight),
2234                        vec![channels],
2235                        false,
2236                    )?)
2237                } else {
2238                    None
2239                }
2240            } else {
2241                None
2242            }
2243        } else {
2244            None
2245        };
2246
2247        let grad_bias_out = if self.affine {
2248            if let Some(ref b) = self.bias {
2249                if b.requires_grad() {
2250                    Some(Tensor::from_storage(
2251                        TensorStorage::cpu(grad_bias),
2252                        vec![channels],
2253                        false,
2254                    )?)
2255                } else {
2256                    None
2257                }
2258            } else {
2259                None
2260            }
2261        } else {
2262            None
2263        };
2264
2265        // Match `inputs()`: when `affine == false` the forward registered only
2266        // `input` as a differentiable leaf, so the grad vec must be length 1.
2267        // Returning weight/bias slots for a 1-input graph makes the autograd
2268        // engine reject the node. Mirrors PyTorch's `grad_input_mask` in
2269        // `aten/src/ATen/native/Normalization.cpp:322-330` (#1567).
2270        if self.affine {
2271            Ok(vec![
2272                Some(grad_input_tensor),
2273                grad_weight_out,
2274                grad_bias_out,
2275            ])
2276        } else {
2277            Ok(vec![Some(grad_input_tensor)])
2278        }
2279    }
2280
2281    fn inputs(&self) -> Vec<&Tensor<T>> {
2282        let mut v: Vec<&Tensor<T>> = vec![&self.input];
2283        if let Some(ref w) = self.weight {
2284            v.push(w);
2285        }
2286        if let Some(ref b) = self.bias {
2287            v.push(b);
2288        }
2289        v
2290    }
2291
2292    fn name(&self) -> &'static str {
2293        "BatchNorm2dBackward"
2294    }
2295}
2296
2297// ===========================================================================
2298// BatchNorm1d
2299// ===========================================================================
2300
2301/// Batch normalization for 2D input `[N, C]` or 3D input `[N, C, L]`.
2302///
2303/// Applies per-channel normalization:
2304///
2305/// ```text
2306/// y = (x - mean) / sqrt(var + eps) * weight + bias
2307/// ```
2308///
2309/// During **training**, `mean` and the biased `var` are computed from the
2310/// current mini-batch over the `(N,)` or `(N, L)` dimensions; exponential moving
2311/// averages are kept in `running_mean` and `running_var` (the latter from the unbiased variance).
2312///
2313/// During **evaluation**, the accumulated `running_mean` and `running_var`
2314/// are used instead of batch statistics.
2315///
2316/// Matches `torch.nn.BatchNorm1d`.
2317pub struct BatchNorm1d<T: Float> {
2318    /// Number of channels (features) `C`.
2319    pub num_features: usize,
2320    /// Small constant for numerical stability.
2321    pub eps: f64,
2322    /// Momentum for the running mean / variance update
2323    /// (`running = (1 - momentum) * running + momentum * batch`).
2324    pub momentum: f64,
2325    /// Whether to apply a learnable affine transform.
2326    pub affine: bool,
2327    /// Learnable scale (gamma), shape `[C]`. `None` when `affine == false`.
2328    pub weight: Option<Parameter<T>>,
2329    /// Learnable shift (beta), shape `[C]`. `None` when `affine == false`.
2330    pub bias: Option<Parameter<T>>,
2331    /// Exponential moving average of per-channel means.
2332    running_mean: Mutex<Vec<f64>>,
2333    /// Exponential moving average of per-channel variances.
2334    running_var: Mutex<Vec<f64>>,
2335    /// Number of forward calls in training mode.
2336    num_batches_tracked: Mutex<usize>,
2337    /// Whether the layer is in training mode.
2338    training: Mutex<bool>,
2339}
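The following minimal sketch covers the training-mode computation described in the doc comment. It assumes a contiguous row-major `[N, C, L]` buffer of `f64` (pass `l = 1` for `[N, C]` input) and omits the affine step. `batch_norm_1d_train` is a hypothetical helper, not part of ferrotorch. It uses the same flat offset (`b * C * L + c * L + l`) and the same biased, divide-by-count variance as the CPU forward below.

```rust
/// Hypothetical sketch of BatchNorm1d's training-mode forward (no affine step)
/// on a contiguous `[N, C, L]` buffer; use `l = 1` for `[N, C]` input.
/// Returns `(normalized, per_channel_mean, per_channel_biased_var)`.
fn batch_norm_1d_train(
    x: &[f64],
    n: usize,
    c: usize,
    l: usize,
    eps: f64,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    assert_eq!(x.len(), n * c * l, "buffer length must equal N * C * L");
    let count = (n * l) as f64; // elements reduced per channel
    let mut mean = vec![0.0; c];
    let mut var = vec![0.0; c];
    let mut out = vec![0.0; x.len()];
    for ch in 0..c {
        // Flat offsets of every (b, ch, i) element: b*C*L + ch*L + i.
        let offsets: Vec<usize> = (0..n)
            .flat_map(|b| (0..l).map(move |i| b * c * l + ch * l + i))
            .collect();
        mean[ch] = offsets.iter().map(|&k| x[k]).sum::<f64>() / count;
        // Biased variance: divide by the count, not count - 1.
        var[ch] = offsets.iter().map(|&k| (x[k] - mean[ch]).powi(2)).sum::<f64>() / count;
        let inv_std = 1.0 / (var[ch] + eps).sqrt();
        for &k in &offsets {
            out[k] = (x[k] - mean[ch]) * inv_std;
        }
    }
    (out, mean, var)
}
```

For a `[2, 2]` input `[[1, 10], [3, 20]]`, channel 0 sees `{1, 3}` and channel 1 sees `{10, 20}`. This gives means `[2, 15]` and biased variances `[1, 25]`.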
2340
2341impl<T: Float> std::fmt::Debug for BatchNorm1d<T> {
2342    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2343        f.debug_struct("BatchNorm1d")
2344            .field("num_features", &self.num_features)
2345            .field("eps", &self.eps)
2346            .field("momentum", &self.momentum)
2347            .field("affine", &self.affine)
2348            .field("weight", &self.weight)
2349            .field("bias", &self.bias)
2350            .field("training", &self.training)
2351            .finish()
2352    }
2353}
2354
2355impl<T: Float> BatchNorm1d<T> {
2356    /// Create a new `BatchNorm1d` layer.
2357    ///
2358    /// # Arguments
2359    ///
2360    /// * `num_features` - Number of channels `C`.
2361    /// * `eps` - Numerical stability constant (PyTorch default: `1e-5`).
2362    /// * `momentum` - Running-statistics momentum (PyTorch default: `0.1`).
2363    /// * `affine` - Whether to include learnable weight and bias.
2364    pub fn new(
2365        num_features: usize,
2366        eps: f64,
2367        momentum: f64,
2368        affine: bool,
2369    ) -> FerrotorchResult<Self> {
2370        if num_features == 0 {
2371            return Err(FerrotorchError::InvalidArgument {
2372                message: "BatchNorm1d: num_features must be positive".into(),
2373            });
2374        }
2375
2376        let weight = if affine {
2377            Some(Parameter::ones(&[num_features])?)
2378        } else {
2379            None
2380        };
2381
2382        let bias = if affine {
2383            Some(Parameter::zeros(&[num_features])?)
2384        } else {
2385            None
2386        };
2387
2388        Ok(Self {
2389            num_features,
2390            eps,
2391            momentum,
2392            affine,
2393            weight,
2394            bias,
2395            running_mean: Mutex::new(vec![0.0; num_features]),
2396            running_var: Mutex::new(vec![1.0; num_features]),
2397            num_batches_tracked: Mutex::new(0),
2398            training: Mutex::new(true),
2399        })
2400    }
2401
2402    /// Access the current running mean (snapshot copy).
2403    pub fn running_mean(&self) -> Vec<f64> {
2404        self.running_mean.lock().unwrap().clone()
2405    }
2406
2407    /// Access the current running variance (snapshot copy).
2408    pub fn running_var(&self) -> Vec<f64> {
2409        self.running_var.lock().unwrap().clone()
2410    }
2411
2412    /// Number of training batches tracked so far.
2413    pub fn num_batches_tracked(&self) -> usize {
2414        *self.num_batches_tracked.lock().unwrap()
2415    }
2416
2417    /// Set the running mean from a slice of length [`num_features`].
2418    ///
2419    /// See [`BatchNorm2d::set_running_mean`] for full semantics — the
2420    /// shape, validation, storage, and round-trip contract are
2421    /// identical (#984).
2422    ///
2423    /// [`num_features`]: Self::num_features
2424    pub fn set_running_mean(&self, value: &[T]) -> FerrotorchResult<()> {
2425        if value.len() != self.num_features {
2426            return Err(FerrotorchError::ShapeMismatch {
2427                message: format!(
2428                    "BatchNorm1d::set_running_mean: expected slice of length \
2429                     num_features={}, got {}",
2430                    self.num_features,
2431                    value.len()
2432                ),
2433            });
2434        }
2435        for (i, x) in value.iter().enumerate() {
2436            if !num_traits::Float::is_finite(*x) {
2437                return Err(FerrotorchError::InvalidArgument {
2438                    message: format!(
2439                        "BatchNorm1d::set_running_mean: non-finite value at \
2440                         index {i} (running_mean must be finite)"
2441                    ),
2442                });
2443            }
2444        }
2445        let mut rm = self.running_mean.lock().unwrap();
2446        for (slot, x) in rm.iter_mut().zip(value.iter()) {
2447            *slot = x.to_f64().unwrap();
2448        }
2449        Ok(())
2450    }
2451
2452    /// Set the running variance from a slice of length [`num_features`].
2453    ///
2454    /// See [`BatchNorm2d::set_running_var`] for full semantics —
2455    /// validation rejects wrong-length, non-finite, and negative
2456    /// entries (#984).
2457    ///
2458    /// [`num_features`]: Self::num_features
2459    pub fn set_running_var(&self, value: &[T]) -> FerrotorchResult<()> {
2460        if value.len() != self.num_features {
2461            return Err(FerrotorchError::ShapeMismatch {
2462                message: format!(
2463                    "BatchNorm1d::set_running_var: expected slice of length \
2464                     num_features={}, got {}",
2465                    self.num_features,
2466                    value.len()
2467                ),
2468            });
2469        }
2470        let zero_t = zero::<T>();
2471        for (i, x) in value.iter().enumerate() {
2472            if !num_traits::Float::is_finite(*x) {
2473                return Err(FerrotorchError::InvalidArgument {
2474                    message: format!(
2475                        "BatchNorm1d::set_running_var: non-finite value at \
2476                         index {i} (running_var must be finite)"
2477                    ),
2478                });
2479            }
2480            if *x < zero_t {
2481                return Err(FerrotorchError::InvalidArgument {
2482                    message: format!(
2483                        "BatchNorm1d::set_running_var: negative value {} at \
2484                         index {i} (running_var must be non-negative)",
2485                        x.to_f64().unwrap()
2486                    ),
2487                });
2488            }
2489        }
2490        let mut rv = self.running_var.lock().unwrap();
2491        for (slot, x) in rv.iter_mut().zip(value.iter()) {
2492            *slot = x.to_f64().unwrap();
2493        }
2494        Ok(())
2495    }
2496
2497    /// Set the number of training batches tracked. (#984)
2498    ///
2499    /// See [`BatchNorm2d::set_num_batches_tracked`] for semantics.
2500    pub fn set_num_batches_tracked(&self, value: usize) -> FerrotorchResult<()> {
2501        let mut nbt = self.num_batches_tracked.lock().unwrap();
2502        *nbt = value;
2503        Ok(())
2504    }
2505}
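`set_running_var` validates the whole slice before taking the lock. A rejected call therefore leaves the stored statistics untouched. Below is a standalone sketch of that all-or-nothing contract. It uses a hypothetical `StatError` enum in place of `FerrotorchError` and plain `f64` in place of `T`. Both `validate_running_var` and `apply_running_var` are illustrative names only.

```rust
/// Hypothetical error type for this sketch (ferrotorch uses `FerrotorchError`).
#[derive(Debug, PartialEq)]
enum StatError {
    Length { expected: usize, got: usize },
    NonFinite { index: usize },
    Negative { index: usize },
}

/// Check a candidate running variance: right length, every entry finite and >= 0.
fn validate_running_var(num_features: usize, value: &[f64]) -> Result<(), StatError> {
    if value.len() != num_features {
        return Err(StatError::Length { expected: num_features, got: value.len() });
    }
    for (index, &x) in value.iter().enumerate() {
        if !x.is_finite() {
            return Err(StatError::NonFinite { index });
        }
        if x < 0.0 {
            return Err(StatError::Negative { index });
        }
    }
    Ok(())
}

/// Write only after the whole slice passes; on error `slots` is left unchanged.
fn apply_running_var(slots: &mut [f64], value: &[f64]) -> Result<(), StatError> {
    validate_running_var(slots.len(), value)?;
    slots.copy_from_slice(value); // lengths are equal after validation
    Ok(())
}
```

`set_running_mean` follows the same pattern without the non-negativity check, since a mean may legitimately be negative.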
2506
2507impl<T: Float> Module<T> for BatchNorm1d<T> {
2508    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2509        let shape = input.shape().to_vec();
2510        let ndim = shape.len();
2511
2512        // Accept 2D [N, C] or 3D [N, C, L].
2513        if ndim != 2 && ndim != 3 {
2514            return Err(FerrotorchError::ShapeMismatch {
2515                message: format!(
2516                    "BatchNorm1d: expected 2D [N, C] or 3D [N, C, L] input, got {:?}",
2517                    shape
2518                ),
2519            });
2520        }
2521
2522        let batch = shape[0];
2523        let channels = shape[1];
2524        let length = if ndim == 3 { shape[2] } else { 1 };
2525
2526        if channels != self.num_features {
2527            return Err(FerrotorchError::ShapeMismatch {
2528                message: format!(
2529                    "BatchNorm1d: expected {} channels, got {}",
2530                    self.num_features, channels
2531                ),
2532            });
2533        }
2534
2535        // Empty batch: nothing to normalize, so return the input unchanged.
2536        if batch == 0 {
2537            return Ok(input.clone());
2538        }
2539
2540        // torch `_verify_batch_size` guard (`torch/nn/functional.py:2811-2814`):
2541        // in training mode the per-channel element count (numel / channels =
2542        // batch * length) must exceed 1 — variance is undefined with a single
2543        // sample. Fires BEFORE the CPU/GPU dispatch so both paths reject
2544        // identically (#1558).
2545        if *self.training.lock().unwrap() && batch * length <= 1 {
2546            return Err(FerrotorchError::InvalidArgument {
2547                message: format!(
2548                    "Expected more than 1 value per channel when training, got input size {:?}",
2549                    shape
2550                ),
2551            });
2552        }
2553
2554        // GPU fast path (#1449): per-channel normalize over (N,) or (N, L).
2555        if input.is_cuda() {
2556            if is_f32::<T>() {
2557                let is_training = *self.training.lock().unwrap();
2558                if let Some(handle) = batch_norm_gpu_forward(
2559                    input,
2560                    self.weight.as_ref().map(|w| w.tensor()),
2561                    self.bias.as_ref().map(|b| b.tensor()),
2562                    &self.running_mean,
2563                    &self.running_var,
2564                    &self.num_batches_tracked,
2565                    self.momentum,
2566                    self.eps,
2567                    channels,
2568                    length,
2569                    is_training,
2570                )? {
2571                    return if is_grad_enabled() && input.requires_grad() {
2572                        let grad_fn = Arc::new(BatchNorm1dBackward {
2573                            input: input.clone(),
2574                            x_hat: Tensor::from_storage(
2575                                TensorStorage::cpu(Vec::new()),
2576                                vec![0],
2577                                false,
2578                            )?,
2579                            weight: self.weight.as_ref().map(|w| w.tensor().clone()),
2580                            bias: self.bias.as_ref().map(|b| b.tensor().clone()),
2581                            chan_var: Vec::new(),
2582                            eps: self.eps,
2583                            affine: self.affine,
2584                            is_training,
2585                            running_mean: self.running_mean.lock().unwrap().clone(),
2586                            running_var: self.running_var.lock().unwrap().clone(),
2587                        });
2588                        Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
2589                    } else {
2590                        Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
2591                    };
2592                }
2593            }
2594            return Err(FerrotorchError::NotImplementedOnCuda {
2595                op: "BatchNorm1d::forward",
2596            });
2597        }
2598        let input_data = input.data()?;
2599        let eps_t = T::from(self.eps).unwrap();
2600
2601        let weight_data = self.weight.as_ref().map(|w| w.tensor().data()).transpose()?;
2602        let bias_data = self.bias.as_ref().map(|b| b.tensor().data()).transpose()?;
2603
2604        let is_training = *self.training.lock().unwrap();
2605
2606        let mut chan_mean = vec![zero::<T>(); channels];
2607        let mut chan_var = vec![zero::<T>(); channels];
2608
2609        if is_training {
2610            let count = batch * length;
2611            let count_t = T::from(count).unwrap();
2612
2613            for c in 0..channels {
2614                let mut s = zero::<T>();
2615                for b in 0..batch {
2616                    let base = b * channels * length + c * length;
2617                    for l in 0..length {
2618                        s += input_data[base + l];
2619                    }
2620                }
2621                chan_mean[c] = s / count_t;
2622
2623                let mut var_sum = zero::<T>();
2624                for b in 0..batch {
2625                    let base = b * channels * length + c * length;
2626                    for l in 0..length {
2627                        let d = input_data[base + l] - chan_mean[c];
2628                        var_sum += d * d;
2629                    }
2630                }
2631                chan_var[c] = var_sum / count_t;
2632            }
2633
2634            // Update running statistics.
2635            {
2636                let mut rm = self.running_mean.lock().unwrap();
2637                let mut rv = self.running_var.lock().unwrap();
2638                let mut nbt = self.num_batches_tracked.lock().unwrap();
2639                *nbt += 1;
2640
2641                let mom = self.momentum;
2642                let bessel = if count > 1 {
2643                    count as f64 / (count as f64 - 1.0)
2644                } else {
2645                    1.0
2646                };
2647
2648                for c in 0..channels {
2649                    let batch_mean_f64 = chan_mean[c].to_f64().unwrap();
2650                    let batch_var_f64 = chan_var[c].to_f64().unwrap();
2651
2652                    rm[c] = (1.0 - mom) * rm[c] + mom * batch_mean_f64;
2653                    rv[c] = (1.0 - mom) * rv[c] + mom * batch_var_f64 * bessel;
2654                }
2655            }
2656        } else {
2657            let rm = self.running_mean.lock().unwrap();
2658            let rv = self.running_var.lock().unwrap();
2659
2660            for c in 0..channels {
2661                chan_mean[c] = T::from(rm[c]).unwrap();
2662                chan_var[c] = T::from(rv[c]).unwrap();
2663            }
2664        }
2665
2666        let mut output = vec![zero::<T>(); input.numel()];
2667
2668        let mut inv_std = vec![zero::<T>(); channels];
2669        let need_x_hat = is_grad_enabled() && input.requires_grad();
2670        let mut x_hat_data = if need_x_hat {
2671            Vec::with_capacity(input.numel())
2672        } else {
2673            Vec::new()
2674        };
2675
2676        for c in 0..channels {
2677            inv_std[c] = (chan_var[c] + eps_t).sqrt().recip();
2678        }
2679
2680        for b in 0..batch {
2681            for c in 0..channels {
2682                let base = b * channels * length + c * length;
2683                for l in 0..length {
2684                    let idx = base + l;
2685                    let normed = (input_data[idx] - chan_mean[c]) * inv_std[c];
2686
2687                    if need_x_hat {
2688                        x_hat_data.push(normed);
2689                    }
2690
2691                    if self.affine {
2692                        let w = weight_data.as_ref().unwrap();
2693                        let bi = bias_data.as_ref().unwrap();
2694                        output[idx] = normed * w[c] + bi[c];
2695                    } else {
2696                        output[idx] = normed;
2697                    }
2698                }
2699            }
2700        }
2701
2702        let result = Tensor::from_storage(TensorStorage::cpu(output), shape.to_vec(), false)?;
2703
2704        if is_grad_enabled() && input.requires_grad() {
2705            let weight_tensor = self.weight.as_ref().map(|w| w.tensor().clone());
2706            let bias_tensor = self.bias.as_ref().map(|b| b.tensor().clone());
2707
2708            let grad_fn = Arc::new(BatchNorm1dBackward {
2709                input: input.clone(),
2710                x_hat: Tensor::from_storage(TensorStorage::cpu(x_hat_data), shape.to_vec(), false)?,
2711                weight: weight_tensor,
2712                bias: bias_tensor,
2713                chan_var: chan_var.iter().map(|v| v.to_f64().unwrap()).collect(),
2714                eps: self.eps,
2715                affine: self.affine,
2716                is_training,
2717                running_mean: self.running_mean.lock().unwrap().clone(),
2718                running_var: self.running_var.lock().unwrap().clone(),
2719            });
2720
2721            Tensor::from_operation(
2722                TensorStorage::cpu(result.data()?.to_vec()),
2723                result.shape().to_vec(),
2724                grad_fn,
2725            )
2726        } else {
2727            Ok(result)
2728        }
2729    }
2730
2731    fn parameters(&self) -> Vec<&Parameter<T>> {
2732        match (&self.weight, &self.bias) {
2733            (Some(w), Some(b)) => vec![w, b],
2734            _ => vec![],
2735        }
2736    }
2737
2738    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
2739        match (&mut self.weight, &mut self.bias) {
2740            (Some(w), Some(b)) => vec![w, b],
2741            _ => vec![],
2742        }
2743    }
2744
2745    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
2746        match (&self.weight, &self.bias) {
2747            (Some(w), Some(b)) => vec![("weight".to_string(), w), ("bias".to_string(), b)],
2748            _ => vec![],
2749        }
2750    }
2751
2752    fn train(&mut self) {
2753        *self.training.lock().unwrap() = true;
2754    }
2755
2756    fn eval(&mut self) {
2757        *self.training.lock().unwrap() = false;
2758    }
2759
2760    fn is_training(&self) -> bool {
2761        *self.training.lock().unwrap()
2762    }
2763
2764    /// Downcast hook for state-dict loaders (#984). See
2765    /// [`BatchNorm2d::as_any`].
2766    fn as_any(&self) -> Option<&dyn std::any::Any> {
2767        Some(self)
2768    }
2769}
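The running-statistics update in the CPU training branch can be sketched without ferrotorch types, together with the `_verify_batch_size` guard that precedes it. `update_running_stats` is a hypothetical helper. Note that the variance fed into the exponential moving average is Bessel-corrected (`count / (count - 1)`), while normalization itself uses the biased variance. Here `count` is the per-channel element count: `N` for `[N, C]` input or `N * L` for `[N, C, L]` input.

```rust
/// Hypothetical sketch of the training-mode running-statistics update:
/// running = (1 - momentum) * running + momentum * batch_stat,
/// with the batch variance Bessel-corrected before it enters the average.
fn update_running_stats(
    running_mean: &mut [f64],
    running_var: &mut [f64],
    batch_mean: &[f64],
    batch_biased_var: &[f64],
    count: usize,
    momentum: f64,
) -> Result<(), String> {
    // Same condition as the `_verify_batch_size` guard: a single value per
    // channel is rejected in training mode.
    if count <= 1 {
        return Err(format!(
            "Expected more than 1 value per channel when training, got count {count}"
        ));
    }
    let bessel = count as f64 / (count as f64 - 1.0);
    for c in 0..running_mean.len() {
        running_mean[c] = (1.0 - momentum) * running_mean[c] + momentum * batch_mean[c];
        running_var[c] =
            (1.0 - momentum) * running_var[c] + momentum * batch_biased_var[c] * bessel;
    }
    Ok(())
}
```

Consider `momentum = 0.1`, `count = 4`, and freshly initialized statistics (`running_mean = 0`, `running_var = 1`). A batch with mean 2.5 and biased variance 1.25 moves `running_mean` to `0.25` and `running_var` to `0.9 + 0.1 * 1.25 * 4 / 3`.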
2770
2771// ---------------------------------------------------------------------------
2772// BatchNorm1dBackward
2773// ---------------------------------------------------------------------------
2774
2775/// Backward node for `BatchNorm1d`.
2776///
2777/// Same math as `BatchNorm2dBackward`, but the per-channel reductions run
2778/// over `(N,)` or `(N, L)` instead of `(N, H, W)`.
2779#[derive(Debug)]
2780struct BatchNorm1dBackward<T: Float> {
2781    input: Tensor<T>,
2782    x_hat: Tensor<T>,
2783    weight: Option<Tensor<T>>,
2784    bias: Option<Tensor<T>>,
2785    chan_var: Vec<f64>,
2786    eps: f64,
2787    affine: bool,
2788    /// Whether the forward pass ran in training mode (#1449).
2789    is_training: bool,
2790    /// Running-mean snapshot for the GPU eval-mode backward (`[channels]`).
2791    running_mean: Vec<f64>,
2792    /// Running-var snapshot for the GPU eval-mode backward (`[channels]`).
2793    running_var: Vec<f64>,
2794}
2795
2796impl<T: Float> GradFn<T> for BatchNorm1dBackward<T> {
2797    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2798        let shape = self.input.shape();
2799        let ndim = shape.len();
2800        let batch = shape[0];
2801        let channels = shape[1];
2802        let length = if ndim == 3 { shape[2] } else { 1 };
2803        let count = batch * length;
2804        let count_t = T::from(count).unwrap();
2805
2806        // GPU-native backward (#1449): on-device, NO `.cpu()` round trip.
2807        if self.input.is_cuda() && is_f32::<T>() {
2808            let weight_dev;
2809            let weight_buf = match self.weight.as_ref() {
2810                Some(w) => w,
2811                None => {
2812                    weight_dev = ferrotorch_core::creation::ones::<T>(&[channels])?
2813                        .to(self.input.device())?;
2814                    &weight_dev
2815                }
2816            };
2817            if let Some(grads) = batch_norm_gpu_backward(
2818                &self.input,
2819                grad_output,
2820                weight_buf,
2821                &self.running_mean,
2822                &self.running_var,
2823                batch,
2824                channels,
2825                length,
2826                self.eps,
2827                self.is_training,
2828                self.affine,
2829                self.weight.as_ref().is_some_and(|w| w.requires_grad()),
2830                self.bias.as_ref().is_some_and(|b| b.requires_grad()),
2831            )? {
2832                return Ok(grads);
2833            }
2834            return Err(FerrotorchError::NotImplementedOnCuda {
2835                op: "BatchNorm1dBackward",
2836            });
2837        }
2838        if self.input.is_cuda() {
2839            return Err(FerrotorchError::NotImplementedOnCuda {
2840                op: "BatchNorm1dBackward",
2841            });
2842        }
2843        let go_data = grad_output.data()?;
2844        let x_hat_data = self.x_hat.data()?;
2845
2846        let weight_data = self.weight.as_ref().map(|w| w.data().map(|d| d.to_vec())).transpose()?;
2847
2848        let mut grad_input = vec![zero::<T>(); self.input.numel()];
2849        let mut grad_weight = vec![zero::<T>(); channels];
2850        let mut grad_bias = vec![zero::<T>(); channels];
2851
2852        for c in 0..channels {
2853            let var_f64 = self.chan_var[c];
2854            let inv_std = T::from(1.0 / (var_f64 + self.eps).sqrt()).unwrap();
2855
2856            let mut dl_dx_hat_sum = zero::<T>();
2857            let mut dl_dx_hat_x_hat_sum = zero::<T>();
2858
2859            for b in 0..batch {
2860                let base = b * channels * length + c * length;
2861                for l in 0..length {
2862                    let idx = base + l;
2863                    let x_h = x_hat_data[idx];
2864                    let go = go_data[idx];
2865
2866                    let dl_dx_hat = if self.affine {
2867                        go * weight_data.as_ref().unwrap()[c]
2868                    } else {
2869                        go
2870                    };
2871
2872                    dl_dx_hat_sum += dl_dx_hat;
2873                    dl_dx_hat_x_hat_sum += dl_dx_hat * x_h;
2874
2875                    if self.affine {
2876                        grad_weight[c] += go * x_h;
2877                        grad_bias[c] += go;
2878                    }
2879                }
2880            }
2881
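2882            let dl_dx_hat_mean = if self.is_training { dl_dx_hat_sum / count_t } else { zero::<T>() }; // eval: running stats are constants, so both batch terms vanish
2883            let dl_dx_hat_x_hat_mean = if self.is_training { dl_dx_hat_x_hat_sum / count_t } else { zero::<T>() };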
2884
2885            for b in 0..batch {
2886                let base = b * channels * length + c * length;
2887                for l in 0..length {
2888                    let idx = base + l;
2889                    let x_h = x_hat_data[idx];
2890                    let go = go_data[idx];
2891
2892                    let dl_dx_hat = if self.affine {
2893                        go * weight_data.as_ref().unwrap()[c]
2894                    } else {
2895                        go
2896                    };
2897
2898                    grad_input[idx] =
2899                        inv_std * (dl_dx_hat - dl_dx_hat_mean - x_h * dl_dx_hat_x_hat_mean);
2900                }
2901            }
2902        }
2903
2904        let grad_input_tensor = Tensor::from_storage(
2905            TensorStorage::cpu(grad_input),
2906            self.input.shape().to_vec(),
2907            false,
2908        )?;
2909
2910        let grad_weight_out = if self.affine {
2911            if let Some(ref w) = self.weight {
2912                if w.requires_grad() {
2913                    Some(Tensor::from_storage(
2914                        TensorStorage::cpu(grad_weight),
2915                        vec![channels],
2916                        false,
2917                    )?)
2918                } else {
2919                    None
2920                }
2921            } else {
2922                None
2923            }
2924        } else {
2925            None
2926        };
2927
2928        let grad_bias_out = if self.affine {
2929            if let Some(ref b) = self.bias {
2930                if b.requires_grad() {
2931                    Some(Tensor::from_storage(
2932                        TensorStorage::cpu(grad_bias),
2933                        vec![channels],
2934                        false,
2935                    )?)
2936                } else {
2937                    None
2938                }
2939            } else {
2940                None
2941            }
2942        } else {
2943            None
2944        };
2945
2946        // Match `inputs()`: when `affine == false` the forward registered only
2947        // `input` as a differentiable leaf, so the grad vec must be length 1.
2948        // Returning weight/bias slots for a 1-input graph makes the autograd
2949        // engine reject the node. Mirrors PyTorch's `grad_input_mask` in
2950        // `aten/src/ATen/native/Normalization.cpp:322-330` (#1567).
2951        if self.affine {
2952            Ok(vec![
2953                Some(grad_input_tensor),
2954                grad_weight_out,
2955                grad_bias_out,
2956            ])
2957        } else {
2958            Ok(vec![Some(grad_input_tensor)])
2959        }
2960    }
2961
2962    fn inputs(&self) -> Vec<&Tensor<T>> {
2963        let mut v: Vec<&Tensor<T>> = vec![&self.input];
2964        if let Some(ref w) = self.weight {
2965            v.push(w);
2966        }
2967        if let Some(ref b) = self.bias {
2968            v.push(b);
2969        }
2970        v
2971    }
2972
2973    fn name(&self) -> &'static str {
2974        "BatchNorm1dBackward"
2975    }
2976}
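The per-channel input gradient can be checked in isolation. The sketch below is a hypothetical `bn_channel_grad_input` over `f64` that handles one channel at a time, with the affine scale folded into `g = dL/dx_hat`. In training mode it implements `inv_std * (g - mean(g) - x_hat * mean(g * x_hat))`, the formula used by the CPU loop above. It also covers the eval-mode case: there the running statistics are constants, so only `inv_std * g` remains.

```rust
/// Hypothetical per-channel input gradient for batch normalization.
/// `x_hat`: normalized activations; `grad_out`: upstream gradient;
/// `var`: the variance used in the forward (batch or running).
fn bn_channel_grad_input(
    x_hat: &[f64],
    grad_out: &[f64],
    weight: f64,
    var: f64,
    eps: f64,
    training: bool,
) -> Vec<f64> {
    let n = x_hat.len() as f64;
    let inv_std = 1.0 / (var + eps).sqrt();
    // g = dL/dx_hat (the affine scale is folded in here).
    let g: Vec<f64> = grad_out.iter().map(|go| go * weight).collect();
    if !training {
        // Running statistics are constants, so only the direct term remains.
        return g.iter().map(|gi| inv_std * gi).collect();
    }
    let g_mean = g.iter().sum::<f64>() / n;
    let gx_mean = g.iter().zip(x_hat).map(|(gi, xh)| gi * xh).sum::<f64>() / n;
    g.iter()
        .zip(x_hat)
        .map(|(gi, xh)| inv_std * (gi - g_mean - xh * gx_mean))
        .collect()
}
```

To catch sign or scaling mistakes in the training branch cheaply, compare it against a central finite difference of `sum(grad_out * x_hat)` with respect to each input element.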
2977
2978// ===========================================================================
2979// BatchNorm3d — CL-434
2980// ===========================================================================
2981
2982/// Batch normalization for 5D input `[N, C, D, H, W]` (a mini-batch of 3D
2983/// volumes with an additional channel dimension).
2984///
2985/// Applies the transform per channel:
2986///
2987/// ```text
2988/// y = (x - mean) / sqrt(var + eps) * weight + bias
2989/// ```
2990///
2991/// During **training**, `mean` and `var` are computed from the current
2992/// mini-batch over the `(N, D, H, W)` dimensions, and exponential moving
2993/// averages of these statistics are maintained in `running_mean` and
2994/// `running_var`.
2995///
2996/// During **evaluation**, the accumulated `running_mean` and `running_var`
2997/// are used instead of batch statistics.
2998///
2999/// Matches `torch.nn.BatchNorm3d`.
3000pub struct BatchNorm3d<T: Float> {
3001    /// Number of channels (features) `C`.
3002    pub num_features: usize,
3003    /// Small constant for numerical stability.
3004    pub eps: f64,
3005    /// Momentum for the running mean / variance update
3006    /// (`running = (1 - momentum) * running + momentum * batch`).
3007    pub momentum: f64,
3008    /// Whether to apply a learnable affine transform.
3009    pub affine: bool,
3010    /// Learnable scale (gamma), shape `[C]`. `None` when `affine == false`.
3011    pub weight: Option<Parameter<T>>,
3012    /// Learnable shift (beta), shape `[C]`. `None` when `affine == false`.
3013    pub bias: Option<Parameter<T>>,
3014    /// Exponential moving average of per-channel means.
3015    running_mean: Mutex<Vec<f64>>,
3016    /// Exponential moving average of per-channel variances.
3017    running_var: Mutex<Vec<f64>>,
3018    /// Number of forward calls in training mode.
3019    num_batches_tracked: Mutex<usize>,
3020    /// Whether the layer is in training mode.
3021    training: Mutex<bool>,
3022}
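On a contiguous buffer, the `(N, D, H, W)` reduction of BatchNorm3d is mathematically the same as BatchNorm1d's `(N, L)` reduction with `L = D * H * W`. The hypothetical sketch below computes only the per-channel mean, assuming a row-major `f64` buffer. `channel_mean_5d` is an illustrative name and does not describe how ferrotorch's BatchNorm3d is implemented.

```rust
/// Hypothetical per-channel mean of a contiguous `[N, C, D, H, W]` buffer,
/// computed by viewing it as `[N, C, L]` with `L = D * H * W`.
fn channel_mean_5d(x: &[f64], shape: [usize; 5]) -> Vec<f64> {
    let [n, c, d, h, w] = shape;
    let l = d * h * w; // collapse (D, H, W) into one axis L
    assert_eq!(x.len(), n * c * l, "buffer length must equal N*C*D*H*W");
    let mut mean = vec![0.0; c];
    for b in 0..n {
        for ch in 0..c {
            let base = b * c * l + ch * l; // same offset formula as BatchNorm1d
            mean[ch] += x[base..base + l].iter().sum::<f64>();
        }
    }
    let count = (n * l) as f64; // elements reduced per channel
    mean.iter().map(|s| s / count).collect()
}
```

The variance and normalization steps follow the same pattern, so `[N, C, D, H, W]` and `[N, C, D * H * W]` inputs produce identical per-channel statistics.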
3023
3024impl<T: Float> std::fmt::Debug for BatchNorm3d<T> {
3025    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3026        f.debug_struct("BatchNorm3d")
3027            .field("num_features", &self.num_features)
3028            .field("eps", &self.eps)
3029            .field("momentum", &self.momentum)
3030            .field("affine", &self.affine)
3031            .field("weight", &self.weight)
3032            .field("bias", &self.bias)
3033            .field("training", &self.training)
3034            .finish()
3035    }
3036}
3037
3038impl<T: Float> BatchNorm3d<T> {
3039    /// Create a new `BatchNorm3d` layer.
3040    ///
3041    /// # Arguments
3042    ///
3043    /// * `num_features` - Number of channels `C`.
3044    /// * `eps` - Numerical stability constant (PyTorch default: `1e-5`).
3045    /// * `momentum` - Running-statistics momentum (PyTorch default: `0.1`).
3046    /// * `affine` - Whether to include learnable weight and bias.
3047    pub fn new(
3048        num_features: usize,
3049        eps: f64,
3050        momentum: f64,
3051        affine: bool,
3052    ) -> FerrotorchResult<Self> {
3053        if num_features == 0 {
3054            return Err(FerrotorchError::InvalidArgument {
3055                message: "BatchNorm3d: num_features must be positive".into(),
3056            });
3057        }
3058
3059        let weight = if affine {
3060            Some(Parameter::ones(&[num_features])?)
3061        } else {
3062            None
3063        };
3064
3065        let bias = if affine {
3066            Some(Parameter::zeros(&[num_features])?)
3067        } else {
3068            None
3069        };
3070
3071        Ok(Self {
3072            num_features,
3073            eps,
3074            momentum,
3075            affine,
3076            weight,
3077            bias,
3078            running_mean: Mutex::new(vec![0.0; num_features]),
3079            running_var: Mutex::new(vec![1.0; num_features]),
3080            num_batches_tracked: Mutex::new(0),
3081            training: Mutex::new(true),
3082        })
3083    }
3084
3085    /// Access the current running mean (snapshot copy).
3086    pub fn running_mean(&self) -> Vec<f64> {
3087        self.running_mean.lock().unwrap().clone()
3088    }
3089
3090    /// Access the current running variance (snapshot copy).
3091    pub fn running_var(&self) -> Vec<f64> {
3092        self.running_var.lock().unwrap().clone()
3093    }
3094
3095    /// Number of training batches tracked so far.
3096    pub fn num_batches_tracked(&self) -> usize {
3097        *self.num_batches_tracked.lock().unwrap()
3098    }
3099
    /// Set the running mean from a slice of length [`num_features`].
    ///
    /// See [`BatchNorm2d::set_running_mean`] for full semantics (#984).
    ///
    /// [`num_features`]: Self::num_features
    pub fn set_running_mean(&self, value: &[T]) -> FerrotorchResult<()> {
        if value.len() != self.num_features {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BatchNorm3d::set_running_mean: expected slice of length \
                     num_features={}, got {}",
                    self.num_features,
                    value.len()
                ),
            });
        }
        for (i, x) in value.iter().enumerate() {
            if !num_traits::Float::is_finite(*x) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "BatchNorm3d::set_running_mean: non-finite value at \
                         index {i} (running_mean must be finite)"
                    ),
                });
            }
        }
        let mut rm = self.running_mean.lock().unwrap();
        for (slot, x) in rm.iter_mut().zip(value.iter()) {
            *slot = x.to_f64().unwrap();
        }
        Ok(())
    }

    /// Set the running variance from a slice of length [`num_features`].
    ///
    /// See [`BatchNorm2d::set_running_var`] for full semantics. Rejects
    /// wrong-length, non-finite, and negative entries (#984).
    ///
    /// [`num_features`]: Self::num_features
    pub fn set_running_var(&self, value: &[T]) -> FerrotorchResult<()> {
        if value.len() != self.num_features {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BatchNorm3d::set_running_var: expected slice of length \
                     num_features={}, got {}",
                    self.num_features,
                    value.len()
                ),
            });
        }
        let zero_t = zero::<T>();
        for (i, x) in value.iter().enumerate() {
            if !num_traits::Float::is_finite(*x) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "BatchNorm3d::set_running_var: non-finite value at \
                         index {i} (running_var must be finite)"
                    ),
                });
            }
            if *x < zero_t {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "BatchNorm3d::set_running_var: negative value {} at \
                         index {i} (running_var must be non-negative)",
                        x.to_f64().unwrap()
                    ),
                });
            }
        }
        let mut rv = self.running_var.lock().unwrap();
        for (slot, x) in rv.iter_mut().zip(value.iter()) {
            *slot = x.to_f64().unwrap();
        }
        Ok(())
    }

    /// Set the number of training batches tracked. (#984)
    ///
    /// See [`BatchNorm2d::set_num_batches_tracked`] for semantics.
    pub fn set_num_batches_tracked(&self, value: usize) -> FerrotorchResult<()> {
        let mut nbt = self.num_batches_tracked.lock().unwrap();
        *nbt = value;
        Ok(())
    }
}

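#[doc = r#"
# Reference sketch (illustrative)

In training mode, `forward` reduces each channel over the `(B, D, H, W)`
axes of a contiguous `[B, C, D, H, W]` buffer. It normalizes with the biased
batch variance, then folds the statistics into the running buffers using an
exponential moving average. The variance term of that average is
Bessel-corrected by `count / (count - 1)`. The standalone `f64` sketch below
restates that arithmetic outside the `Tensor` API. `channel_stats` and
`update_running` are hypothetical helpers written for this illustration, not
crate items.

```rust
/// Per-channel mean and biased variance of a contiguous
/// `[batch, channels, spatial]` buffer, where `spatial = D * H * W`.
fn channel_stats(x: &[f64], batch: usize, channels: usize, spatial: usize) -> (Vec<f64>, Vec<f64>) {
    let count = (batch * spatial) as f64;
    let mut mean = vec![0.0; channels];
    let mut var = vec![0.0; channels];
    for c in 0..channels {
        // Element (b, c, s) sits at b * channels * spatial + c * spatial + s.
        let at = |b: usize, s: usize| x[b * channels * spatial + c * spatial + s];
        let mut sum = 0.0;
        for b in 0..batch {
            for s in 0..spatial {
                sum += at(b, s);
            }
        }
        mean[c] = sum / count;
        let mut sq = 0.0;
        for b in 0..batch {
            for s in 0..spatial {
                let d = at(b, s) - mean[c];
                sq += d * d;
            }
        }
        // Biased variance: this is what normalizes the current batch.
        var[c] = sq / count;
    }
    (mean, var)
}

/// Exponential moving average of the running statistics. The running
/// variance receives the unbiased batch variance, `var * count / (count - 1)`.
fn update_running(
    running_mean: &mut [f64],
    running_var: &mut [f64],
    mean: &[f64],
    var: &[f64],
    momentum: f64,
    count: usize,
) {
    let bessel = if count > 1 { count as f64 / (count as f64 - 1.0) } else { 1.0 };
    for c in 0..mean.len() {
        running_mean[c] = (1.0 - momentum) * running_mean[c] + momentum * mean[c];
        running_var[c] = (1.0 - momentum) * running_var[c] + momentum * var[c] * bessel;
    }
}
```

In eval mode, the same normalization reads the stored running mean and
variance instead of batch statistics, and it leaves the running buffers
unchanged.
"#]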
impl<T: Float> Module<T> for BatchNorm3d<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape().to_vec();
        if shape.len() != 5 {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BatchNorm3d: expected 5D input [B, C, D, H, W], got {:?}",
                    shape
                ),
            });
        }

        let batch = shape[0];
        let channels = shape[1];
        let depth = shape[2];
        let height = shape[3];
        let width = shape[4];
        let spatial = depth * height * width;

        if channels != self.num_features {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "BatchNorm3d: expected {} channels, got {}",
                    self.num_features, channels
                ),
            });
        }

        if batch == 0 {
            return Ok(input.clone());
        }

        // torch `_verify_batch_size` guard (`torch/nn/functional.py:2811-2814`):
        // in training mode the per-channel element count (numel / channels =
        // batch * spatial) must exceed 1, because the variance is undefined with
        // a single sample. Fires BEFORE the CPU/GPU dispatch so both paths
        // reject identically (#1558).
        if *self.training.lock().unwrap() && batch * spatial <= 1 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Expected more than 1 value per channel when training, got input size {:?}",
                    shape
                ),
            });
        }

        // GPU fast path (#1449): per-channel normalize over (B, D, H, W).
        if input.is_cuda() {
            if is_f32::<T>() {
                let is_training = *self.training.lock().unwrap();
                if let Some(handle) = batch_norm_gpu_forward(
                    input,
                    self.weight.as_ref().map(|w| w.tensor()),
                    self.bias.as_ref().map(|b| b.tensor()),
                    &self.running_mean,
                    &self.running_var,
                    &self.num_batches_tracked,
                    self.momentum,
                    self.eps,
                    channels,
                    spatial,
                    is_training,
                )? {
                    return if is_grad_enabled() && input.requires_grad() {
                        let grad_fn = Arc::new(BatchNorm3dBackward {
                            input: input.clone(),
                            x_hat: Tensor::from_storage(
                                TensorStorage::cpu(Vec::new()),
                                vec![0],
                                false,
                            )?,
                            weight: self.weight.as_ref().map(|w| w.tensor().clone()),
                            bias: self.bias.as_ref().map(|b| b.tensor().clone()),
                            chan_var: Vec::new(),
                            eps: self.eps,
                            affine: self.affine,
                            is_training,
                            running_mean: self.running_mean.lock().unwrap().clone(),
                            running_var: self.running_var.lock().unwrap().clone(),
                        });
                        Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
                    } else {
                        Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
                    };
                }
            }
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "BatchNorm3d::forward",
            });
        }
        let input_data = input.data()?;
        let eps_t = T::from(self.eps).unwrap();

        let weight_data = self.weight.as_ref().map(|w| w.tensor().data().unwrap());
        let bias_data = self.bias.as_ref().map(|b| b.tensor().data().unwrap());

        let is_training = *self.training.lock().unwrap();

        let mut chan_mean = vec![zero::<T>(); channels];
        let mut chan_var = vec![zero::<T>(); channels];

        if is_training {
            let count = batch * spatial;
            let count_t = T::from(count).unwrap();

            for c in 0..channels {
                let mut sum = zero::<T>();
                for b in 0..batch {
                    let base = b * channels * spatial + c * spatial;
                    for s in 0..spatial {
                        sum += input_data[base + s];
                    }
                }
                chan_mean[c] = sum / count_t;

                let mut var_sum = zero::<T>();
                for b in 0..batch {
                    let base = b * channels * spatial + c * spatial;
                    for s in 0..spatial {
                        let d = input_data[base + s] - chan_mean[c];
                        var_sum += d * d;
                    }
                }
                chan_var[c] = var_sum / count_t;
            }

            // Update running statistics.
            {
                let mut rm = self.running_mean.lock().unwrap();
                let mut rv = self.running_var.lock().unwrap();
                let mut nbt = self.num_batches_tracked.lock().unwrap();
                *nbt += 1;

                let mom = self.momentum;
                let bessel = if count > 1 {
                    count as f64 / (count as f64 - 1.0)
                } else {
                    1.0
                };

                for c in 0..channels {
                    let batch_mean_f64 = chan_mean[c].to_f64().unwrap();
                    let batch_var_f64 = chan_var[c].to_f64().unwrap();

                    rm[c] = (1.0 - mom) * rm[c] + mom * batch_mean_f64;
                    rv[c] = (1.0 - mom) * rv[c] + mom * batch_var_f64 * bessel;
                }
            }
        } else {
            let rm = self.running_mean.lock().unwrap();
            let rv = self.running_var.lock().unwrap();

            for c in 0..channels {
                chan_mean[c] = T::from(rm[c]).unwrap();
                chan_var[c] = T::from(rv[c]).unwrap();
            }
        }

        let mut output = vec![zero::<T>(); input.numel()];

        let mut inv_std = vec![zero::<T>(); channels];
        let need_x_hat = is_grad_enabled() && input.requires_grad();
        let mut x_hat_data = if need_x_hat {
            Vec::with_capacity(input.numel())
        } else {
            Vec::new()
        };

        for c in 0..channels {
            inv_std[c] = (chan_var[c] + eps_t).sqrt().recip();
        }

        for b in 0..batch {
            for c in 0..channels {
                let base = b * channels * spatial + c * spatial;
                for s in 0..spatial {
                    let idx = base + s;
                    let normed = (input_data[idx] - chan_mean[c]) * inv_std[c];

                    if need_x_hat {
                        x_hat_data.push(normed);
                    }

                    if self.affine {
                        let w = weight_data.as_ref().unwrap();
                        let bi = bias_data.as_ref().unwrap();
                        output[idx] = normed * w[c] + bi[c];
                    } else {
                        output[idx] = normed;
                    }
                }
            }
        }

        // Build the output storage once and hand it to whichever constructor
        // applies, instead of copying the data out of a temporary tensor.
        let storage = TensorStorage::cpu(output);

        if need_x_hat {
            let weight_tensor = self.weight.as_ref().map(|w| w.tensor().clone());
            let bias_tensor = self.bias.as_ref().map(|b| b.tensor().clone());

            let grad_fn = Arc::new(BatchNorm3dBackward {
                input: input.clone(),
                x_hat: Tensor::from_storage(TensorStorage::cpu(x_hat_data), shape.clone(), false)?,
                weight: weight_tensor,
                bias: bias_tensor,
                chan_var: chan_var.iter().map(|v| v.to_f64().unwrap()).collect(),
                eps: self.eps,
                affine: self.affine,
                is_training,
                running_mean: self.running_mean.lock().unwrap().clone(),
                running_var: self.running_var.lock().unwrap().clone(),
            });

            Tensor::from_operation(storage, shape, grad_fn)
        } else {
            Tensor::from_storage(storage, shape, false)
        }
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        match (&self.weight, &self.bias) {
            (Some(w), Some(b)) => vec![w, b],
            _ => vec![],
        }
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        match (&mut self.weight, &mut self.bias) {
            (Some(w), Some(b)) => vec![w, b],
            _ => vec![],
        }
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        match (&self.weight, &self.bias) {
            (Some(w), Some(b)) => vec![("weight".to_string(), w), ("bias".to_string(), b)],
            _ => vec![],
        }
    }

    fn train(&mut self) {
        *self.training.lock().unwrap() = true;
    }

    fn eval(&mut self) {
        *self.training.lock().unwrap() = false;
    }

    fn is_training(&self) -> bool {
        *self.training.lock().unwrap()
    }

    /// Downcast hook for state-dict loaders (#984). See
    /// [`BatchNorm2d::as_any`].
    fn as_any(&self) -> Option<&dyn std::any::Any> {
        Some(self)
    }
}

// ---------------------------------------------------------------------------
// BatchNorm3dBackward
// ---------------------------------------------------------------------------

/// Backward node for `BatchNorm3d`.
///
/// Same math as `BatchNorm2dBackward`, except that the per-channel reductions
/// run over the `(B, D, H, W)` axes instead of `(B, H, W)`.
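#[doc = r#"
# Reference sketch (illustrative)

For each channel, the CPU backward first forms `dl_dx_hat`, which is the
upstream gradient scaled by the affine weight when `affine` is set. It then
maps that value to the input gradient. The standalone `f64` helper below
restates this step. `bn_channel_grad_input` is a hypothetical name, not a
crate item.

The helper also spells out the eval-mode case. When the forward normalized
with the running statistics, those statistics are constants, so only the
direct `inv_std * dl_dx_hat` term remains. PyTorch's eval-mode batch-norm
backward has the same form.

```rust
/// Input gradient for one channel.
///
/// * `x_hat` - normalized activations `(x - mean) * inv_std`
/// * `dl_dx_hat` - upstream gradient times the affine weight (or 1)
/// * `inv_std` - `1 / sqrt(var + eps)`
/// * `training` - whether the forward used batch statistics
fn bn_channel_grad_input(x_hat: &[f64], dl_dx_hat: &[f64], inv_std: f64, training: bool) -> Vec<f64> {
    if !training {
        // Running statistics do not depend on the input.
        return dl_dx_hat.iter().map(|g| inv_std * g).collect();
    }
    // Batch mean and variance depend on every element of the channel, which
    // adds the two mean-correction terms.
    let n = x_hat.len() as f64;
    let mean_g = dl_dx_hat.iter().sum::<f64>() / n;
    let mean_gx = dl_dx_hat.iter().zip(x_hat).map(|(g, x)| g * x).sum::<f64>() / n;
    x_hat
        .iter()
        .zip(dl_dx_hat)
        .map(|(x, g)| inv_std * (g - mean_g - x * mean_gx))
        .collect()
}
```

With `training = true`, this is the same three-term expression that the CPU
loop evaluates with `dl_dx_hat_mean` and `dl_dx_hat_x_hat_mean`.
"#]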
#[derive(Debug)]
struct BatchNorm3dBackward<T: Float> {
    input: Tensor<T>,
    x_hat: Tensor<T>,
    weight: Option<Tensor<T>>,
    bias: Option<Tensor<T>>,
    chan_var: Vec<f64>,
    eps: f64,
    affine: bool,
    /// GPU backward (#1449): forward training-mode flag.
    is_training: bool,
    /// Running-mean snapshot for the GPU eval-mode backward (`[channels]`).
    running_mean: Vec<f64>,
    /// Running-var snapshot for the GPU eval-mode backward (`[channels]`).
    running_var: Vec<f64>,
}

impl<T: Float> GradFn<T> for BatchNorm3dBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let shape = self.input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let spatial: usize = shape[2..].iter().product();
        let count = batch * spatial;
        let count_t = T::from(count).unwrap();

        // GPU-native backward (#1449): on-device, NO `.cpu()` round trip.
        if self.input.is_cuda() && is_f32::<T>() {
            let weight_dev;
            let weight_buf = match self.weight.as_ref() {
                Some(w) => w,
                None => {
                    weight_dev = ferrotorch_core::creation::ones::<T>(&[channels])?
                        .to(self.input.device())?;
                    &weight_dev
                }
            };
            if let Some(grads) = batch_norm_gpu_backward(
                &self.input,
                grad_output,
                weight_buf,
                &self.running_mean,
                &self.running_var,
                batch,
                channels,
                spatial,
                self.eps,
                self.is_training,
                self.affine,
                self.weight.as_ref().is_some_and(|w| w.requires_grad()),
                self.bias.as_ref().is_some_and(|b| b.requires_grad()),
            )? {
                return Ok(grads);
            }
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "BatchNorm3dBackward",
            });
        }
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "BatchNorm3dBackward",
            });
        }
        let go_data = grad_output.data()?;
        let x_hat_data = self.x_hat.data()?;

        let weight_data = self.weight.as_ref().map(|w| w.data().unwrap().to_vec());

        let mut grad_input = vec![zero::<T>(); self.input.numel()];
        let mut grad_weight = vec![zero::<T>(); channels];
        let mut grad_bias = vec![zero::<T>(); channels];

        for c in 0..channels {
            let var_f64 = self.chan_var[c];
            let inv_std = T::from(1.0 / (var_f64 + self.eps).sqrt()).unwrap();

            let mut dl_dx_hat_sum = zero::<T>();
            let mut dl_dx_hat_x_hat_sum = zero::<T>();

            for b in 0..batch {
                let base = b * channels * spatial + c * spatial;
                for s in 0..spatial {
                    let idx = base + s;
                    let x_h = x_hat_data[idx];
                    let go = go_data[idx];

                    let dl_dx_hat = if self.affine {
                        go * weight_data.as_ref().unwrap()[c]
                    } else {
                        go
                    };

                    dl_dx_hat_sum += dl_dx_hat;
                    dl_dx_hat_x_hat_sum += dl_dx_hat * x_h;

                    if self.affine {
                        grad_weight[c] += go * x_h;
                        grad_bias[c] += go;
                    }
                }
            }

            let dl_dx_hat_mean = dl_dx_hat_sum / count_t;
            let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / count_t;

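            for b in 0..batch {
                let base = b * channels * spatial + c * spatial;
                for s in 0..spatial {
                    let idx = base + s;
                    let x_h = x_hat_data[idx];
                    let go = go_data[idx];

                    let dl_dx_hat = if self.affine {
                        go * weight_data.as_ref().unwrap()[c]
                    } else {
                        go
                    };

                    grad_input[idx] = if self.is_training {
                        // Batch statistics depend on every element in the
                        // channel, which adds the two mean-correction terms.
                        inv_std * (dl_dx_hat - dl_dx_hat_mean - x_h * dl_dx_hat_x_hat_mean)
                    } else {
                        // Eval mode normalizes with the running statistics,
                        // which are constants w.r.t. the input, so only the
                        // direct term survives (as in PyTorch's eval backward).
                        inv_std * dl_dx_hat
                    };
                }
            }
        }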

        let grad_input_tensor = Tensor::from_storage(
            TensorStorage::cpu(grad_input),
            self.input.shape().to_vec(),
            false,
        )?;

        let grad_weight_out = if self.affine {
            if let Some(ref w) = self.weight {
                if w.requires_grad() {
                    Some(Tensor::from_storage(
                        TensorStorage::cpu(grad_weight),
                        vec![channels],
                        false,
                    )?)
                } else {
                    None
                }
            } else {
                None
            }
        } else {
            None
        };

        let grad_bias_out = if self.affine {
            if let Some(ref b) = self.bias {
                if b.requires_grad() {
                    Some(Tensor::from_storage(
                        TensorStorage::cpu(grad_bias),
                        vec![channels],
                        false,
                    )?)
                } else {
                    None
                }
            } else {
                None
            }
        } else {
            None
        };

        // Match `inputs()`: when `affine == false` the forward registered only
        // `input` as an input of this node, so the gradient vector must have
        // length 1. Returning weight/bias slots for a single-input node makes
        // the autograd engine reject it. Mirrors PyTorch's `grad_input_mask` in
        // `aten/src/ATen/native/Normalization.cpp:322-330` (#1567).
        if self.affine {
            Ok(vec![
                Some(grad_input_tensor),
                grad_weight_out,
                grad_bias_out,
            ])
        } else {
            Ok(vec![Some(grad_input_tensor)])
        }
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        let mut v: Vec<&Tensor<T>> = vec![&self.input];
        if let Some(ref w) = self.weight {
            v.push(w);
        }
        if let Some(ref b) = self.bias {
            v.push(b);
        }
        v
    }

    fn name(&self) -> &'static str {
        "BatchNorm3dBackward"
    }
}

// ===========================================================================
// LocalResponseNorm (CL-435)
// ===========================================================================

/// Local Response Normalization (cross-channel normalization).
///
/// Applies the transform:
///
/// ```text
/// output[c] = input[c] / (k + alpha/size * sum(input[j]^2 for j in [c - size/2, c + (size-1)/2]))^beta
/// ```
///
/// where `/` in the window bounds is integer division. The window therefore
/// always spans exactly `size` channels, and for even `size` it extends one
/// channel further to the left than to the right. Channels outside `[0, C)`
/// are treated as zero padding: they add nothing to the sum, but the divisor
/// stays `size`.
///
/// Parameters:
/// - `size`: number of neighbouring channels to normalize over.
/// - `alpha`: multiplicative factor (default: `1e-4`).
/// - `beta`: exponent (default: `0.75`).
/// - `k`: additive constant (default: `1.0`).
///
/// This layer has no learnable parameters.
///
/// Matches `torch.nn.LocalResponseNorm`.
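#[doc = r#"
# Reference sketch (illustrative)

PyTorch zero-pads the squared input with `size / 2` channels in front and
`(size - 1) / 2` channels behind, then average-pools it with kernel `size`.
The scalar sketch below reproduces that indexing over the channels of a
single spatial position. `lrn_window` and `lrn_forward` are hypothetical
helpers written for this illustration, not crate items.

```rust
/// Half-open window `[start, end)` of input channels summed for output
/// channel `c`, clipped to `[0, channels)`.
fn lrn_window(c: usize, size: usize, channels: usize) -> (usize, usize) {
    let half = size / 2; // channels to the left of `c`
    let upper = size - half; // `c` itself plus `(size - 1) / 2` channels to the right
    (c.saturating_sub(half), (c + upper).min(channels))
}

/// LRN across the channel axis at one spatial position.
fn lrn_forward(x: &[f64], size: usize, alpha: f64, beta: f64, k: f64) -> Vec<f64> {
    (0..x.len())
        .map(|c| {
            let (start, end) = lrn_window(c, size, x.len());
            let sq_sum: f64 = x[start..end].iter().map(|v| v * v).sum();
            // The divisor is always `size`, even where the window is clipped,
            // which is exactly zero padding.
            let d = sq_sum / size as f64 * alpha + k;
            x[c] / d.powf(beta)
        })
        .collect()
}
```

With `size = 2`, output channel 3 sums channels 2 and 3. With `size = 3`, it
sums channels 2, 3, and 4.
"#]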
#[derive(Debug, Clone)]
pub struct LocalResponseNorm {
    pub size: usize,
    pub alpha: f64,
    pub beta: f64,
    pub k: f64,
    /// Training-mode flag. Carried for Module-trait consistency; the
    /// layer itself is stateless and produces the same output in both
    /// modes (matches PyTorch's `LocalResponseNorm` behaviour).
    training: bool,
}

impl LocalResponseNorm {
    /// Create a new `LocalResponseNorm` layer.
    ///
    /// # Arguments
    ///
    /// * `size` - Number of neighbouring channels used for normalization.
    /// * `alpha` - Multiplicative factor (default: `1e-4`).
    /// * `beta` - Exponent (default: `0.75`).
    /// * `k` - Additive constant (default: `1.0`).
    pub fn new(size: usize, alpha: f64, beta: f64, k: f64) -> FerrotorchResult<Self> {
        if size == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "LocalResponseNorm: size must be positive".into(),
            });
        }
        Ok(Self {
            size,
            alpha,
            beta,
            k,
            training: true,
        })
    }

    /// Create with default alpha=1e-4, beta=0.75, k=1.0.
    pub fn default_params(size: usize) -> FerrotorchResult<Self> {
        Self::new(size, 1e-4, 0.75, 1.0)
    }
}

impl<T: Float> Module<T> for LocalResponseNorm {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let shape = input.shape().to_vec();
        if shape.len() < 3 {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "LocalResponseNorm: expected at least 3D input [B, C, ...], got {:?}",
                    shape
                ),
            });
        }

        let batch = shape[0];
        let channels = shape[1];
        let spatial: usize = shape[2..].iter().product();

        // GPU-native forward (#1449): square → windowed channel sum → affine →
        // pow(beta) → divide, on-device. NO `.cpu()` round trip (R-CODE-4).
        // The saved `denom` buffer is kept GPU-resident for the backward.
        // f32-only.
        if input.is_cuda() && is_f32::<T>() {
            if let Some(backend) = gpu_backend() {
                let (out_h, denom_h) = backend.local_response_norm_f32(
                    input.gpu_handle()?,
                    batch,
                    channels,
                    spatial,
                    self.size,
                    self.alpha as f32,
                    self.beta as f32,
                    self.k as f32,
                )?;
                let denom_gpu =
                    Tensor::from_storage(TensorStorage::gpu(denom_h), shape.clone(), false)?;
                return if is_grad_enabled() && input.requires_grad() {
                    Tensor::from_operation(
                        TensorStorage::gpu(out_h),
                        shape,
                        Arc::new(LocalResponseNormBackward {
                            input: input.clone(),
                            denom: Vec::new(),
                            denom_gpu: Some(denom_gpu),
                            size: self.size,
                            alpha: self.alpha,
                            beta: self.beta,
                        }),
                    )
                } else {
                    Tensor::from_storage(TensorStorage::gpu(out_h), shape, false)
                };
            }
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "LocalResponseNorm::forward",
            });
        }
        if input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "LocalResponseNorm::forward",
            });
        }
        let input_data = input.data()?;
        let alpha_t = T::from(self.alpha).unwrap();
        let beta_t = T::from(self.beta).unwrap();
        let k_t = T::from(self.k).unwrap();
        let size_t = T::from(self.size).unwrap();
        // Window indexing must match PyTorch's `torch/nn/functional.py:3032-3046`:
        //   pad(div, (..., size//2, (size-1)//2)); avg_pool with kernel size.
        // In original-channel coordinates the window for output channel `c` is
        //   [c - size//2, c - size//2 + size)   (i.e. `size` channels wide,
        //                                        with implicit zero pad at edges)
        // For ODD `size` this is `[c - half, c + half + 1)` (symmetric).
        // For EVEN `size` the right edge is `c + (size+1)/2 = c + half` (not
        // `c + half + 1`), so the asymmetric pad shifts the window LEFT by 1.
        // The previous `c + half + 1` upper bound made the window `size+1`
        // wide for even sizes, producing a 0.006 drift at shape [1,6,3] size=2.
        let half = self.size / 2;
        let upper = self.size - half; // == (size + 1) / 2

        let mut output = vec![zero::<T>(); input.numel()];

        // Per-element denominator `D = k + alpha/size * window_sum(x^2)`,
        // saved for the backward pass.
        let mut denom = vec![zero::<T>(); input.numel()];

        // Mirror PyTorch's expression order exactly:
        //   div = avg_pool(x*x) * alpha + k          (sq_sum/size, *alpha, +k)
        //   div = div ^ beta
        //   out = x / div
        for b in 0..batch {
            for c in 0..channels {
                let c_start = c.saturating_sub(half);
                let c_end = (c + upper).min(channels);

                for s in 0..spatial {
                    let mut sq_sum = zero::<T>();
                    for j in c_start..c_end {
                        let jidx = b * channels * spatial + j * spatial + s;
                        sq_sum += input_data[jidx] * input_data[jidx];
                    }

                    let idx = b * channels * spatial + c * spatial + s;
                    let d = sq_sum / size_t * alpha_t + k_t;
                    denom[idx] = d;
                    output[idx] = input_data[idx] / d.powf(beta_t);
                }
            }
        }

        let storage = TensorStorage::cpu(output);

        if is_grad_enabled() && input.requires_grad() {
            Tensor::from_operation(
                storage,
                shape,
                Arc::new(LocalResponseNormBackward {
                    input: input.clone(),
                    denom,
                    denom_gpu: None,
                    size: self.size,
                    alpha: self.alpha,
                    beta: self.beta,
                }),
            )
        } else {
            Tensor::from_storage(storage, shape, false)
        }
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        vec![]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        vec![]
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        vec![]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }
}

// ---------------------------------------------------------------------------
// LocalResponseNormBackward
// ---------------------------------------------------------------------------

/// Backward node for `LocalResponseNorm`.
///
/// The forward computes
/// ```text
/// y_c = x_c * D_c^(-beta)
/// ```
/// where `D_c = k + (alpha/size) * sum_{j in window(c)} x_j^2`.
///
/// With upstream gradient `g_c = dL/dy_c`, the input gradient is
/// ```text
/// dL/dx_i = g_i * D_i^(-beta)
///         - 2*beta*alpha/size * x_i * sum_{c : i in window(c)} (g_c * x_c * D_c^(-beta-1))
/// ```
/// The cross term sums over every output channel `c` whose forward window
/// contains input channel `i`.
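#[doc = r#"
# Reference sketch (illustrative)

The cross term inverts the forward window. Let `half = size / 2` and
`upper = size - half`. Output channel `c` reads `[c - half, c + upper)`, so
input channel `i` receives a contribution from every `c` in
`[i + 1 - upper, i + half]`, clipped to valid channels. The scalar sketch
below applies the formula at one spatial position. `lrn_backward` is a
hypothetical helper, not a crate item, and it recomputes `D` rather than
reusing a saved buffer.

```rust
/// Input gradient of LRN at one spatial position, given the upstream
/// gradient `g[c] = dL/dy[c]`.
fn lrn_backward(x: &[f64], g: &[f64], size: usize, alpha: f64, beta: f64, k: f64) -> Vec<f64> {
    let n = x.len();
    let half = size / 2;
    let upper = size - half;
    // D[c] = k + alpha / size * (sum of x[j]^2 over the forward window of c).
    let d: Vec<f64> = (0..n)
        .map(|c| {
            let window = &x[c.saturating_sub(half)..(c + upper).min(n)];
            k + alpha / size as f64 * window.iter().map(|v| v * v).sum::<f64>()
        })
        .collect();
    (0..n)
        .map(|i| {
            // Output channels whose forward window contains input channel i.
            let lo = (i + 1).saturating_sub(upper);
            let hi = (i + half + 1).min(n);
            let cross: f64 = (lo..hi).map(|c| g[c] * x[c] * d[c].powf(-beta - 1.0)).sum();
            g[i] * d[i].powf(-beta) - 2.0 * alpha * beta / size as f64 * x[i] * cross
        })
        .collect()
}
```

A central finite difference of `sum_c g[c] * y[c]` is a quick way to check a
sketch like this against the forward formula.
"#]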
3890#[derive(Debug)]
3891struct LocalResponseNormBackward<T: Float> {
3892    input: Tensor<T>,
3893    /// Pre-computed denominator `D_c` per element (CPU path). Empty on GPU.
3894    denom: Vec<T>,
3895    /// GPU backward (#1449): the GPU-resident `denom` buffer saved by the
3896    /// forward kernel, consumed directly by the backward kernel.
3897    denom_gpu: Option<Tensor<T>>,
3898    size: usize,
3899    alpha: f64,
3900    beta: f64,
3901}
3902
impl<T: Float> GradFn<T> for LocalResponseNormBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if !self.input.requires_grad() {
            return Ok(vec![None]);
        }

        // GPU-native backward (#1449): consumes the saved GPU `denom` buffer and
        // computes grad_input on-device. NO `.cpu()` round trip (R-CODE-4).
        if self.input.is_cuda() && is_f32::<T>() {
            if let (Some(backend), Some(denom_gpu)) = (gpu_backend(), self.denom_gpu.as_ref()) {
                let shape = self.input.shape();
                let batch = shape[0];
                let channels = shape[1];
                let spatial: usize = shape[2..].iter().product();
                let gi_h = backend.local_response_norm_backward_f32(
                    self.input.gpu_handle()?,
                    grad_output.gpu_handle()?,
                    denom_gpu.gpu_handle()?,
                    batch,
                    channels,
                    spatial,
                    self.size,
                    self.alpha as f32,
                    self.beta as f32,
                )?;
                let grad_input =
                    Tensor::from_storage(TensorStorage::gpu(gi_h), shape.to_vec(), false)?;
                return Ok(vec![Some(grad_input)]);
            }
            // No registered GPU backend, or the forward pass saved no GPU `denom`.
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "LocalResponseNormBackward",
            });
        }
        // CUDA input of a dtype other than f32: no kernel is available.
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "LocalResponseNormBackward",
            });
        }

        let shape = self.input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let spatial: usize = shape[2..].iter().product();

        let input_data = self.input.data()?;
        let go_data = grad_output.data()?;
        let alpha_t = T::from(self.alpha).unwrap();
        let beta_t = T::from(self.beta).unwrap();
        let size_t = T::from(self.size).unwrap();
        let two = T::from(2.0).unwrap();
        // Forward window for output `c` is `[c - half, c + upper)` (width `size`).
        // Backward iterates over all output channels `c` whose forward window
        // included input channel `i_c`, i.e.
        //   c - half <= i_c < c + upper
        //   <=> c in [i_c - upper + 1, i_c + half + 1)
        // For odd `size` this collapses to the symmetric `[i_c - half, i_c + half + 1)`
        // (same as the legacy code); for even `size` it shifts the window by 1
        // channel to match the asymmetric forward pad.
        let half = self.size / 2;
        let upper = self.size - half; // == (size + 1) / 2

        let mut grad_input = vec![zero::<T>(); self.input.numel()];

        for b in 0..batch {
            for i_c in 0..channels {
                for s in 0..spatial {
                    let i_idx = b * channels * spatial + i_c * spatial + s;

                    // Term 1: D_i^(-beta) * grad_output
                    let term1 = self.denom[i_idx].powf(-beta_t) * go_data[i_idx];

                    // Term 2: cross-channel interaction
                    // For each channel c whose window includes i_c:
                    // contribution = -2*beta*alpha/size * x_i * x_c * D_c^(-beta-1) * go_c
                    let c_start = (i_c + 1).saturating_sub(upper);
                    let c_end = (i_c + half + 1).min(channels);

                    let mut cross_sum = zero::<T>();
                    for c in c_start..c_end {
                        let c_idx = b * channels * spatial + c * spatial + s;
                        cross_sum += go_data[c_idx]
                            * input_data[c_idx]
                            * self.denom[c_idx].powf(-beta_t - T::from(1.0).unwrap());
                    }

                    grad_input[i_idx] =
                        term1 - two * beta_t * alpha_t / size_t * input_data[i_idx] * cross_sum;
                }
            }
        }

        let grad_tensor = Tensor::from_storage(
            TensorStorage::cpu(grad_input),
            self.input.shape().to_vec(),
            false,
        )?;
        Ok(vec![Some(grad_tensor)])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    fn name(&self) -> &'static str {
        "LocalResponseNormBackward"
    }
}
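The window inversion in the comments of `backward` (forward window `[c - half, c + upper)` versus backward range `[i_c - upper + 1, i_c + half + 1)`) can be confirmed by brute force. This is a small sketch; `forward_window`, `backward_range`, and `windows_consistent` are hypothetical helpers that reproduce the same index arithmetic, including clipping to `[0, channels)`.

```rust
use std::ops::Range;

/// Forward window of output channel `c`, clipped to the valid channels.
fn forward_window(c: usize, size: usize, channels: usize) -> Range<usize> {
    let (half, upper) = (size / 2, size - size / 2);
    c.saturating_sub(half)..(c + upper).min(channels)
}

/// All output channels whose forward window contains input channel `i`,
/// computed with the same formula as the backward loop.
fn backward_range(i: usize, size: usize, channels: usize) -> Range<usize> {
    let (half, upper) = (size / 2, size - size / 2);
    (i + 1).saturating_sub(upper)..(i + half + 1).min(channels)
}

/// True if `backward_range` matches a brute-force scan for every input channel.
fn windows_consistent(size: usize, channels: usize) -> bool {
    (0..channels).all(|i| {
        let brute: Vec<usize> = (0..channels)
            .filter(|&c| forward_window(c, size, channels).contains(&i))
            .collect();
        let derived: Vec<usize> = backward_range(i, size, channels).collect();
        brute == derived
    })
}
```

For `size = 4` the backward range of channel 3 is `2..6` (shifted right by one), while for `size = 3` it is the symmetric `2..5`.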

// ===========================================================================
// InstanceNorm — CL-315
// ===========================================================================

/// Instance normalization: normalizes each **(batch, channel)** slice
/// independently, i.e. statistics are computed over the spatial dimensions
/// only, never across the batch or across channels.
///
/// This is equivalent to `GroupNorm` with `num_groups == num_channels`, but
/// semantically emphasised as a per-instance, per-channel operation.
///
/// Unlike `BatchNorm`, `InstanceNorm` does **not** maintain running
/// statistics (PyTorch's `track_running_stats=True` mode has no counterpart
/// here), so its behaviour is identical in train and eval modes.
///
/// `InstanceNormInner<T>` is the private engine shared by the public wrapper
/// structs `InstanceNorm1d`, `InstanceNorm2d`, and `InstanceNorm3d`, which
/// differ only in the input rank they accept (3, 4, and 5 dimensions).
#[derive(Debug)]
struct InstanceNormInner<T: Float> {
    /// Number of channels (features) `C`.
    num_features: usize,
    /// Small constant for numerical stability.
    eps: f64,
    /// Whether to apply learnable affine parameters.
    affine: bool,
    /// Learnable scale (gamma), shape `[C]`. Always allocated (ones); only
    /// applied and trained when `affine` is set.
    weight: Parameter<T>,
    /// Learnable shift (beta), shape `[C]`. Always allocated (zeros); only
    /// applied and trained when `affine` is set.
    bias: Parameter<T>,
    /// Train/eval flag; it does not affect the computation (no running stats).
    training: bool,
}
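The claimed equivalence with `GroupNorm` can be illustrated on a contiguous `[B, C, S]` buffer. This is a sketch on plain `f64` slices that leaves out the affine step; `group_norm` and `instance_norm` are hypothetical helpers, not this crate's API.

```rust
/// Group norm (no affine) over a contiguous `[b, c, s]` buffer.
/// In row-major layout each group is one contiguous run of `(c / groups) * s` values.
fn group_norm(x: &[f64], b: usize, c: usize, s: usize, groups: usize, eps: f64) -> Vec<f64> {
    assert_eq!(x.len(), b * c * s);
    assert!(groups > 0 && c % groups == 0);
    let group_len = (c / groups) * s;
    let mut out = vec![0.0; x.len()];
    for (xin, yout) in x.chunks(group_len).zip(out.chunks_mut(group_len)) {
        let n = xin.len() as f64;
        let mean = xin.iter().sum::<f64>() / n;
        let var = xin.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
        let inv_std = 1.0 / (var + eps).sqrt();
        for (o, v) in yout.iter_mut().zip(xin) {
            *o = (v - mean) * inv_std;
        }
    }
    out
}

/// Instance norm (no affine): statistics per (batch, channel) slice, indexed
/// the same way as the CPU loop in `forward_impl`.
fn instance_norm(x: &[f64], b: usize, c: usize, s: usize, eps: f64) -> Vec<f64> {
    let mut out = vec![0.0; x.len()];
    for bi in 0..b {
        for ci in 0..c {
            let base = bi * c * s + ci * s; // start of the (bi, ci) slice
            let slice = &x[base..base + s];
            let mean = slice.iter().sum::<f64>() / s as f64;
            let var = slice.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / s as f64;
            let inv_std = 1.0 / (var + eps).sqrt();
            for k in 0..s {
                out[base + k] = (slice[k] - mean) * inv_std;
            }
        }
    }
    out
}
```

With `groups == c` every group is exactly one `(batch, channel)` slice, so the two functions agree. With fewer groups, statistics are shared across channels and the results differ.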

impl<T: Float> InstanceNormInner<T> {
    fn new(num_features: usize, eps: f64, affine: bool) -> FerrotorchResult<Self> {
        if num_features == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "InstanceNorm: num_features must be positive".into(),
            });
        }

        let weight = Parameter::ones(&[num_features])?;
        let bias = Parameter::zeros(&[num_features])?;

        Ok(Self {
            num_features,
            eps,
            affine,
            weight,
            bias,
            training: true,
        })
    }

    /// Forward for input of shape `[B, C, *spatial]`.
    /// `expected_ndim` (3 = 1d, 4 = 2d, 5 = 3d) is the required input rank and
    /// also selects the layer name used in error messages.
    fn forward_impl(&self, input: &Tensor<T>, expected_ndim: usize) -> FerrotorchResult<Tensor<T>> {
        let label = match expected_ndim {
            3 => "InstanceNorm1d",
            4 => "InstanceNorm2d",
            _ => "InstanceNorm3d",
        };
        let shape = input.shape().to_vec();

        if shape.len() != expected_ndim {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!("{label}: expected {expected_ndim}D input, got {:?}", shape),
            });
        }

        let batch = shape[0];
        let channels = shape[1];
        if channels != self.num_features {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "{label}: expected {} channels, got {}",
                    self.num_features, channels
                ),
            });
        }

        let spatial: usize = shape[2..].iter().product();
        if spatial == 0 {
            return Ok(input.clone());
        }

        // GPU fast path (#1449): InstanceNorm is exactly GroupNorm with
        // `num_groups == num_channels`: each (batch, channel) slice is its
        // own normalization group over the spatial dims, with a per-channel
        // affine. `weight`/`bias` always have length `num_features` (ones /
        // zeros when `affine == false`, so the kernel's unconditional affine
        // is the identity). PyTorch's `torch/nn/functional.py::instance_norm`
        // reaches the same per-instance statistics differently: ATen's
        // `instance_norm` views the input as `[1, B*C, *spatial]` and calls
        // batch norm with `weight`/`bias` repeated `B` times.
        if input.is_cuda() {
            // The kernel is f32-only; other dtypes fall through to the error below.
            if let Some(backend) = gpu_backend().filter(|_| is_f32::<T>()) {
                let eps_f32 = self.eps as f32;
                let handle = backend.group_norm_f32(
                    input.gpu_handle()?,
                    self.weight.tensor().gpu_handle()?,
                    self.bias.tensor().gpu_handle()?,
                    batch,
                    channels,
                    channels, // num_groups == num_channels ⇒ InstanceNorm
                    spatial,
                    eps_f32,
                )?;
                return if is_grad_enabled() && input.requires_grad() {
                    let grad_fn = Arc::new(InstanceNormBackward {
                        input: input.clone(),
                        weight: self.weight.tensor().clone(),
                        bias: self.bias.tensor().clone(),
                        num_features: self.num_features,
                        eps: self.eps,
                        affine: self.affine,
                    });
                    Tensor::from_operation(TensorStorage::gpu(handle), shape, grad_fn)
                } else {
                    Tensor::from_storage(TensorStorage::gpu(handle), shape, false)
                };
            }
            // CUDA input that is not f32, or no registered GPU backend: reject honestly.
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "InstanceNorm::forward",
            });
        }
        let input_data = input.data()?;
        let eps_t = T::from(self.eps).unwrap();
        let n_t = T::from(spatial).unwrap();

        let weight_data = self.weight.tensor().data()?;
        let bias_data = self.bias.tensor().data()?;

        let mut output = vec![zero::<T>(); input.numel()];

        for b in 0..batch {
            for c in 0..channels {
                let base = b * channels * spatial + c * spatial;
                let slice = &input_data[base..base + spatial];

                // Compute mean and variance over spatial dims for this (b, c).
                let mean = slice.iter().copied().fold(zero::<T>(), |a, x| a + x) / n_t;
                let var = slice.iter().copied().fold(zero::<T>(), |a, x| {
                    let d = x - mean;
                    a + d * d
                }) / n_t;
                let inv_std = (var + eps_t).sqrt().recip();

                for s in 0..spatial {
                    let idx = base + s;
                    let normed = (input_data[idx] - mean) * inv_std;
                    if self.affine {
                        output[idx] = normed * weight_data[c] + bias_data[c];
                    } else {
                        output[idx] = normed;
                    }
                }
            }
        }

        // Wrap `output` directly (no intermediate copy); a grad_fn is attached
        // only when autograd is recording and the input requires grad.
        let storage = TensorStorage::cpu(output);
        if is_grad_enabled() && input.requires_grad() {
            let grad_fn = Arc::new(InstanceNormBackward {
                input: input.clone(),
                weight: self.weight.tensor().clone(),
                bias: self.bias.tensor().clone(),
                num_features: self.num_features,
                eps: self.eps,
                affine: self.affine,
            });
            Tensor::from_operation(storage, shape, grad_fn)
        } else {
            Tensor::from_storage(storage, shape, false)
        }
    }
}
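ATen's `instance_norm` gets the same per-instance statistics by viewing `[B, C, S]` as `[1, B*C, S]` and running training-mode batch norm with `weight`/`bias` repeated `B` times. The GPU backward of `InstanceNormBackward` uses the same reshape. The sketch below checks that equivalence on plain `f64` slices, including the per-channel affine. `batch_norm_train`, `instance_norm_affine`, and `instance_norm_via_batch_norm` are hypothetical helpers.

```rust
/// Training-mode batch norm over a contiguous `[n, ch, s]` buffer: statistics
/// per channel over the batch and spatial axes.
fn batch_norm_train(
    x: &[f64], n: usize, ch: usize, s: usize, w: &[f64], bias: &[f64], eps: f64,
) -> Vec<f64> {
    let mut out = vec![0.0; x.len()];
    let count = (n * s) as f64;
    for c in 0..ch {
        // Flat indices of every element of channel `c`, across the batch.
        let mut idxs = Vec::with_capacity(n * s);
        for b in 0..n {
            for k in 0..s {
                idxs.push(b * ch * s + c * s + k);
            }
        }
        let mean = idxs.iter().map(|&i| x[i]).sum::<f64>() / count;
        let var = idxs.iter().map(|&i| (x[i] - mean).powi(2)).sum::<f64>() / count;
        let inv_std = 1.0 / (var + eps).sqrt();
        for &i in &idxs {
            out[i] = (x[i] - mean) * inv_std * w[c] + bias[c];
        }
    }
    out
}

/// Direct instance norm with a per-channel affine over `[b, ch, s]`.
fn instance_norm_affine(
    x: &[f64], b: usize, ch: usize, s: usize, w: &[f64], bias: &[f64], eps: f64,
) -> Vec<f64> {
    let mut out = vec![0.0; x.len()];
    for bi in 0..b {
        for c in 0..ch {
            let base = bi * ch * s + c * s;
            let xs = &x[base..base + s];
            let mean = xs.iter().sum::<f64>() / s as f64;
            let var = xs.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / s as f64;
            let inv_std = 1.0 / (var + eps).sqrt();
            for k in 0..s {
                out[base + k] = (xs[k] - mean) * inv_std * w[c] + bias[c];
            }
        }
    }
    out
}

/// Reshape trick: view `[b, ch, s]` as `[1, b*ch, s]` and repeat `w`/`bias`
/// `b` times so that pseudo-channel `bi*ch + c` uses `w[c]` and `bias[c]`.
fn instance_norm_via_batch_norm(
    x: &[f64], b: usize, ch: usize, s: usize, w: &[f64], bias: &[f64], eps: f64,
) -> Vec<f64> {
    let w_rep: Vec<f64> = w.iter().copied().cycle().take(b * ch).collect();
    let bias_rep: Vec<f64> = bias.iter().copied().cycle().take(b * ch).collect();
    batch_norm_train(x, 1, b * ch, s, &w_rep, &bias_rep, eps)
}
```

The reshape is free because element `(bi, c, k)` of a row-major `[B, C, S]` buffer already sits at `(bi*C + c)*S + k`. That is exactly pseudo-channel `bi*C + c` of the `[1, B*C, S]` view.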
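
// ---------------------------------------------------------------------------
// InstanceNormBackward
// ---------------------------------------------------------------------------

/// Backward node for InstanceNorm.
///
/// Same VJP as GroupNorm / LayerNorm, but the normalization group is
/// a single **(batch, channel)** slice over the spatial dims. With
/// `g_hat = grad_output * weight[c]` (or `grad_output` when `affine == false`)
/// and means taken over the spatial elements of the slice:
/// `grad_input = inv_std * (g_hat - mean(g_hat) - x_hat * mean(g_hat * x_hat))`,
/// `grad_weight[c] = sum(grad_output * x_hat)` and `grad_bias[c] = sum(grad_output)`,
/// where both sums run over the batch and spatial dims.
#[derive(Debug)]
struct InstanceNormBackward<T: Float> {
    input: Tensor<T>,
    weight: Tensor<T>,
    bias: Tensor<T>,
    num_features: usize,
    eps: f64,
    affine: bool,
}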
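The VJP described above, including the accumulation of `grad_weight`/`grad_bias` over the batch, can be checked against central differences. This is a minimal sketch on a contiguous `[B, C, S]` `f64` buffer under the loss `L = sum(g * y)`. It mirrors the CPU loop of `InstanceNormBackward::backward`, but the helpers `slice_stats`, `inorm_forward`, `inorm_vjp`, and `max_grad_error` are hypothetical.

```rust
/// `(mean, inv_std)` of one (batch, channel) slice.
fn slice_stats(xs: &[f64], eps: f64) -> (f64, f64) {
    let n = xs.len() as f64;
    let mean = xs.iter().sum::<f64>() / n;
    let var = xs.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    (mean, 1.0 / (var + eps).sqrt())
}

/// Forward: `y = w[c] * x_hat + bias[c]`. Slices are stored in (b, c) order,
/// so slice `k` belongs to channel `k % ch`.
fn inorm_forward(x: &[f64], ch: usize, s: usize, w: &[f64], bias: &[f64], eps: f64) -> Vec<f64> {
    let mut y = Vec::with_capacity(x.len());
    for (k, xs) in x.chunks(s).enumerate() {
        let c = k % ch;
        let (mean, inv_std) = slice_stats(xs, eps);
        y.extend(xs.iter().map(|v| w[c] * (v - mean) * inv_std + bias[c]));
    }
    y
}

/// Returns `(grad_input, grad_weight, grad_bias)` for upstream gradient `g`.
fn inorm_vjp(x: &[f64], g: &[f64], ch: usize, s: usize, w: &[f64], eps: f64) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    let (mut gx, mut gw, mut gb) = (Vec::with_capacity(x.len()), vec![0.0; ch], vec![0.0; ch]);
    let n = s as f64;
    for (k, (xs, gs)) in x.chunks(s).zip(g.chunks(s)).enumerate() {
        let c = k % ch;
        let (mean, inv_std) = slice_stats(xs, eps);
        let x_hat: Vec<f64> = xs.iter().map(|v| (v - mean) * inv_std).collect();
        let g_hat: Vec<f64> = gs.iter().map(|gv| gv * w[c]).collect(); // dL/dx_hat
        let mean_g = g_hat.iter().sum::<f64>() / n;
        let mean_gx = g_hat.iter().zip(&x_hat).map(|(a, b)| a * b).sum::<f64>() / n;
        for j in 0..s {
            gx.push(inv_std * (g_hat[j] - mean_g - x_hat[j] * mean_gx));
            gw[c] += gs[j] * x_hat[j]; // accumulates over every slice of channel c
            gb[c] += gs[j];
        }
    }
    (gx, gw, gb)
}

/// Largest |analytic - central difference| over all inputs and weights.
fn max_grad_error(x: &[f64], g: &[f64], ch: usize, s: usize, w: &[f64], bias: &[f64], eps: f64) -> f64 {
    let loss = |xv: &[f64], wv: &[f64]| -> f64 {
        inorm_forward(xv, ch, s, wv, bias, eps).iter().zip(g).map(|(y, gv)| y * gv).sum()
    };
    let (gx, gw, _) = inorm_vjp(x, g, ch, s, w, eps);
    let h = 1e-6;
    let mut err = 0.0f64;
    for i in 0..x.len() {
        let (mut xp, mut xm) = (x.to_vec(), x.to_vec());
        xp[i] += h;
        xm[i] -= h;
        err = err.max(((loss(&xp, w) - loss(&xm, w)) / (2.0 * h) - gx[i]).abs());
    }
    for c in 0..ch {
        let (mut wp, mut wm) = (w.to_vec(), w.to_vec());
        wp[c] += h;
        wm[c] -= h;
        err = err.max(((loss(x, &wp) - loss(x, &wm)) / (2.0 * h) - gw[c]).abs());
    }
    err
}
```

A non-uniform upstream gradient `g` together with a non-unit `w` is what exercises the mean-subtraction terms. With `g` all ones and `w = 1`, `grad_input` would be zero and the check would be vacuous.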

impl<T: Float> GradFn<T> for InstanceNormBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let shape = self.input.shape();
        let batch = shape[0];
        let channels = shape[1];
        let spatial: usize = shape[2..].iter().product();
        let n_t = T::from(spatial).unwrap();
        let eps_t = T::from(self.eps).unwrap();

        // GPU-native backward (#1449): InstanceNorm is BatchNorm applied
        // per-instance. Reshape `[B, C, S]` → `[1, B*C, S]` so each (b, c)
        // becomes its own normalization "channel" reduced over spatial only
        // (InstanceNorm always uses instance stats, i.e. training=true). The
        // per-channel affine `weight[c]` is tiled to `[B*C]`; grad_weight /
        // grad_bias come back `[B*C]` and are summed over the batch axis
        // on-device (`sum_axis_f32`) to `[C]`. grad_input keeps `[B, C, S]`.
        // The full grad-data path stays on-GPU, with NO `.cpu()` round trip
        // (R-CODE-4). f32-only.
        if self.input.is_cuda() && is_f32::<T>() {
            if let Some(grads) = instance_norm_gpu_backward(
                &self.input,
                grad_output,
                &self.weight,
                batch,
                channels,
                spatial,
                self.eps,
                self.affine,
                self.weight.requires_grad(),
                self.bias.requires_grad(),
            )? {
                return Ok(grads);
            }
            // The GPU helper declined this input (returned `None`).
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "InstanceNormBackward",
            });
        }
        // CUDA input of a dtype other than f32: no kernel is available.
        if self.input.is_cuda() {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "InstanceNormBackward",
            });
        }
        let input_data = self.input.data()?;
        let go_data = grad_output.data()?;
        let weight_data = self.weight.data()?;

        let mut grad_input = vec![zero::<T>(); self.input.numel()];
        let mut grad_weight = vec![zero::<T>(); self.num_features];
        let mut grad_bias = vec![zero::<T>(); self.num_features];

        for b in 0..batch {
            for c in 0..channels {
                let base = b * channels * spatial + c * spatial;
                let x_slice = &input_data[base..base + spatial];
                let go_slice = &go_data[base..base + spatial];

                // Recompute mean and inv_std for this (b, c).
                let mean = x_slice.iter().copied().fold(zero::<T>(), |a, x| a + x) / n_t;
                let var = x_slice.iter().copied().fold(zero::<T>(), |a, x| {
                    let d = x - mean;
                    a + d * d
                }) / n_t;
                let inv_std = (var + eps_t).sqrt().recip();

                // Accumulate sums for the VJP.
                let mut dl_dx_hat_sum = zero::<T>();
                let mut dl_dx_hat_x_hat_sum = zero::<T>();

                for s in 0..spatial {
                    let x_hat = (x_slice[s] - mean) * inv_std;
                    let dl_dx_hat = if self.affine {
                        go_slice[s] * weight_data[c]
                    } else {
                        go_slice[s]
                    };
                    dl_dx_hat_sum += dl_dx_hat;
                    dl_dx_hat_x_hat_sum += dl_dx_hat * x_hat;

                    if self.affine {
                        grad_weight[c] += go_slice[s] * x_hat;
                        grad_bias[c] += go_slice[s];
                    }
                }

                let dl_dx_hat_mean = dl_dx_hat_sum / n_t;
                let dl_dx_hat_x_hat_mean = dl_dx_hat_x_hat_sum / n_t;

                for s in 0..spatial {
                    let x_hat = (x_slice[s] - mean) * inv_std;
                    let dl_dx_hat = if self.affine {
                        go_slice[s] * weight_data[c]
                    } else {
                        go_slice[s]
                    };
                    grad_input[base + s] =
                        inv_std * (dl_dx_hat - dl_dx_hat_mean - x_hat * dl_dx_hat_x_hat_mean);
                }
            }
        }

        let grad_input_tensor = Tensor::from_storage(
            TensorStorage::cpu(grad_input),
            self.input.shape().to_vec(),
            false,
        )?;

        let grad_weight_out = if self.affine && self.weight.requires_grad() {
            Some(Tensor::from_storage(
                TensorStorage::cpu(grad_weight),
                vec![self.num_features],
                false,
            )?)
        } else {
            None
        };

        let grad_bias_out = if self.affine && self.bias.requires_grad() {
            Some(Tensor::from_storage(
                TensorStorage::cpu(grad_bias),
                vec![self.num_features],
                false,
            )?)
        } else {
            None
        };

        Ok(vec![
            Some(grad_input_tensor),
            grad_weight_out,
            grad_bias_out,
        ])
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input, &self.weight, &self.bias]
    }

    fn name(&self) -> &'static str {
        "InstanceNormBackward"
    }
}

// ---------------------------------------------------------------------------
// InstanceNorm1d — CL-315
// ---------------------------------------------------------------------------

/// Instance normalization for 3D input `[N, C, L]`.
///
/// Normalizes each `(n, c)` slice independently over the `L` dimension.
/// No running statistics are maintained.
///
/// Mirrors `torch.nn.InstanceNorm1d` with `track_running_stats=False`. Only
/// batched `[N, C, L]` input is accepted; PyTorch additionally accepts
/// unbatched `[C, L]` input.
#[derive(Debug)]
pub struct InstanceNorm1d<T: Float> {
    inner: InstanceNormInner<T>,
}

impl<T: Float> InstanceNorm1d<T> {
    /// Create a new `InstanceNorm1d` layer.
    ///
    /// # Arguments
    ///
    /// * `num_features` - Number of channels `C`; must be positive, otherwise an error is returned.
    /// * `eps` - Constant added to the variance for numerical stability (PyTorch's default is `1e-5`).
    /// * `affine` - Whether to include learnable weight and bias (PyTorch's default is `false`).
    pub fn new(num_features: usize, eps: f64, affine: bool) -> FerrotorchResult<Self> {
        Ok(Self {
            inner: InstanceNormInner::new(num_features, eps, affine)?,
        })
    }
}

impl<T: Float> Module<T> for InstanceNorm1d<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        self.inner.forward_impl(input, 3)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        if self.inner.affine {
            vec![&self.inner.weight, &self.inner.bias]
        } else {
            vec![]
        }
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        if self.inner.affine {
            vec![&mut self.inner.weight, &mut self.inner.bias]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        if self.inner.affine {
            vec![
                ("weight".to_string(), &self.inner.weight),
                ("bias".to_string(), &self.inner.bias),
            ]
        } else {
            vec![]
        }
    }

    fn train(&mut self) {
        self.inner.training = true;
    }

    fn eval(&mut self) {
        self.inner.training = false;
    }

    fn is_training(&self) -> bool {
        self.inner.training
    }
}

// ---------------------------------------------------------------------------
// InstanceNorm2d — CL-315
// ---------------------------------------------------------------------------

/// Instance normalization for 4D input `[N, C, H, W]`.
///
/// Normalizes each `(n, c)` slice independently over the `(H, W)` dimensions.
/// No running statistics are maintained.
///
/// Mirrors `torch.nn.InstanceNorm2d` with `track_running_stats=False`. Only
/// batched `[N, C, H, W]` input is accepted; PyTorch additionally accepts
/// unbatched `[C, H, W]` input.
#[derive(Debug)]
pub struct InstanceNorm2d<T: Float> {
    inner: InstanceNormInner<T>,
}

impl<T: Float> InstanceNorm2d<T> {
    /// Create a new `InstanceNorm2d` layer.
    ///
    /// # Arguments
    ///
    /// * `num_features` - Number of channels `C`; must be positive, otherwise an error is returned.
    /// * `eps` - Constant added to the variance for numerical stability (PyTorch's default is `1e-5`).
    /// * `affine` - Whether to include learnable weight and bias (PyTorch's default is `false`).
    pub fn new(num_features: usize, eps: f64, affine: bool) -> FerrotorchResult<Self> {
        Ok(Self {
            inner: InstanceNormInner::new(num_features, eps, affine)?,
        })
    }
}

impl<T: Float> Module<T> for InstanceNorm2d<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        self.inner.forward_impl(input, 4)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        if self.inner.affine {
            vec![&self.inner.weight, &self.inner.bias]
        } else {
            vec![]
        }
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        if self.inner.affine {
            vec![&mut self.inner.weight, &mut self.inner.bias]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        if self.inner.affine {
            vec![
                ("weight".to_string(), &self.inner.weight),
                ("bias".to_string(), &self.inner.bias),
            ]
        } else {
            vec![]
        }
    }

    fn train(&mut self) {
        self.inner.training = true;
    }

    fn eval(&mut self) {
        self.inner.training = false;
    }

    fn is_training(&self) -> bool {
        self.inner.training
    }
}

// ---------------------------------------------------------------------------
// InstanceNorm3d — CL-315
// ---------------------------------------------------------------------------

/// Instance normalization for 5D input `[N, C, D, H, W]`.
///
/// Normalizes each `(n, c)` slice independently over the `(D, H, W)` dims.
/// No running statistics are maintained.
///
/// Mirrors `torch.nn.InstanceNorm3d` with `track_running_stats=False`. Only
/// batched `[N, C, D, H, W]` input is accepted; PyTorch additionally accepts
/// unbatched `[C, D, H, W]` input.
#[derive(Debug)]
pub struct InstanceNorm3d<T: Float> {
    inner: InstanceNormInner<T>,
}

impl<T: Float> InstanceNorm3d<T> {
    /// Create a new `InstanceNorm3d` layer.
    ///
    /// # Arguments
    ///
    /// * `num_features` - Number of channels `C`; must be positive, otherwise an error is returned.
    /// * `eps` - Constant added to the variance for numerical stability (PyTorch's default is `1e-5`).
    /// * `affine` - Whether to include learnable weight and bias (PyTorch's default is `false`).
    pub fn new(num_features: usize, eps: f64, affine: bool) -> FerrotorchResult<Self> {
        Ok(Self {
            inner: InstanceNormInner::new(num_features, eps, affine)?,
        })
    }
}

impl<T: Float> Module<T> for InstanceNorm3d<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        self.inner.forward_impl(input, 5)
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        if self.inner.affine {
            vec![&self.inner.weight, &self.inner.bias]
        } else {
            vec![]
        }
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        if self.inner.affine {
            vec![&mut self.inner.weight, &mut self.inner.bias]
        } else {
            vec![]
        }
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        if self.inner.affine {
            vec![
                ("weight".to_string(), &self.inner.weight),
                ("bias".to_string(), &self.inner.bias),
            ]
        } else {
            vec![]
        }
    }

    fn train(&mut self) {
        self.inner.training = true;
    }

    fn eval(&mut self) {
        self.inner.training = false;
    }

    fn is_training(&self) -> bool {
        self.inner.training
    }
}

// ===========================================================================
// Tests
// ===========================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::autograd::no_grad::no_grad;

    /// Helper: create a leaf tensor with given data, shape, and requires_grad.
    fn leaf(data: &[f64], shape: &[usize], requires_grad: bool) -> Tensor<f64> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            shape.to_vec(),
            requires_grad,
        )
        .unwrap()
    }

    // -----------------------------------------------------------------------
    // LayerNorm tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_layer_norm_parameter_shapes() {
        let ln = LayerNorm::<f32>::new(vec![8], 1e-5, true).unwrap();
        let params = ln.parameters();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].shape(), &[8]); // weight
        assert_eq!(params[1].shape(), &[8]); // bias
    }

    #[test]
    fn test_layer_norm_no_affine_no_params() {
        let ln = LayerNorm::<f32>::new(vec![8], 1e-5, false).unwrap();
        assert_eq!(ln.parameters().len(), 0);
    }

    #[test]
    fn test_layer_norm_forward_zero_mean_unit_var() {
        // After LayerNorm (with default weight=1, bias=0), each row should
        // have approximately zero mean and unit variance.
        let data: Vec<f32> = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // row 0
            -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, // row 1
        ];
        let input = Tensor::from_storage(TensorStorage::cpu(data), vec![2, 8], false).unwrap();

        let ln = LayerNorm::<f32>::new(vec![8], 1e-5, true).unwrap();
        let output = ln.forward(&input).unwrap();
        let out_data = output.data().unwrap();

        for row in 0..2 {
            let start = row * 8;
            let end = start + 8;
            let row_data = &out_data[start..end];

            let mean: f32 = row_data.iter().sum::<f32>() / 8.0;
            let var: f32 = row_data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / 8.0;

            assert!(mean.abs() < 1e-5, "row {row} mean = {mean}, expected ~0");
            assert!(
                (var - 1.0).abs() < 0.05,
                "row {row} var = {var}, expected ~1"
            );
        }
    }
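The variance tolerance in the test above is generous. With `eps > 0` the normalized row has variance exactly `var / (var + eps)`, which sits just below 1. For the row `1..=8`, `var = 5.25`. The following `f64` sketch shows this. `variance` and `layer_norm_row` are hypothetical helpers (no affine), not the crate's `LayerNorm`.

```rust
/// Population variance of a slice (divides by `n`, as the normalization layers do).
fn variance(xs: &[f64]) -> f64 {
    let mean = xs.iter().sum::<f64>() / xs.len() as f64;
    xs.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / xs.len() as f64
}

/// LayerNorm over one row without affine: `(x - mean) / sqrt(var + eps)`.
fn layer_norm_row(xs: &[f64], eps: f64) -> Vec<f64> {
    let mean = xs.iter().sum::<f64>() / xs.len() as f64;
    let inv_std = 1.0 / (variance(xs) + eps).sqrt();
    xs.iter().map(|v| (v - mean) * inv_std).collect()
}
```

The shrink factor `var / (var + eps)` is about `1 - 2e-6` here, so almost all of the `0.05` tolerance is headroom for `f32` rounding.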

    #[test]
    fn test_layer_norm_forward_shape_preserved() {
        let input =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 24]), vec![2, 3, 4], false)
                .unwrap();

        let ln = LayerNorm::<f32>::new(vec![4], 1e-5, true).unwrap();
        let output = ln.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 3, 4]);
    }

    #[test]
    fn test_layer_norm_shape_mismatch() {
        let input =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 12]), vec![3, 4], false)
                .unwrap();

        let ln = LayerNorm::<f32>::new(vec![5], 1e-5, true).unwrap();
        assert!(ln.forward(&input).is_err());
    }

    #[test]
    fn test_layer_norm_empty_normalized_shape() {
        assert!(LayerNorm::<f32>::new(vec![], 1e-5, true).is_err());
    }

    #[test]
    fn test_layer_norm_has_grad_fn_when_input_requires_grad() {
        let input = Tensor::<f32>::from_storage(
            TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0]),
            vec![1, 4],
            true,
        )
        .unwrap();

        let ln = LayerNorm::<f32>::new(vec![4], 1e-5, true).unwrap();
        let output = ln.forward(&input).unwrap();
        assert!(output.grad_fn().is_some());
        assert_eq!(output.grad_fn().unwrap().name(), "LayerNormBackward");
    }

    #[test]
    fn test_layer_norm_no_grad_fn_in_no_grad_context() {
        let input = Tensor::<f32>::from_storage(
            TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0]),
            vec![1, 4],
            true,
        )
        .unwrap();

        let ln = LayerNorm::<f32>::new(vec![4], 1e-5, true).unwrap();
        let output = no_grad(|| ln.forward(&input)).unwrap();
        assert!(output.grad_fn().is_none());
    }

    #[test]
    fn test_layer_norm_backward_gradient_check() -> FerrotorchResult<()> {
        // Numerical gradient check for LayerNorm on a small input, in f64 for
        // better precision.
        //
        // Caveat: with the default affine (weight = 1, bias = 0) each output row
        // sums to 0 up to rounding, because the normalized values `x_hat` sum to
        // zero by construction. The loss `sum(LN(x))` is therefore constant in
        // `x`, so both gradients below are ~0. This check only catches errors
        // that break that cancellation (e.g. a missing mean-subtraction term);
        // a non-uniform upstream gradient would exercise the full VJP.
        let h = 1e-7;
        let hidden = 4;
        let input_data = vec![1.0f64, -0.5, 2.0, 0.3];

        let ln = LayerNorm::<f64>::new(vec![hidden], 1e-5, true)?;

        // Forward and backward.
        let input = leaf(&input_data, &[1, hidden], true);
        let output = ln.forward(&input)?;
        let out_data = output.data()?.to_vec();
        let total: f64 = out_data.iter().sum();

        // Build sum backward manually.
        let sum_gf = Arc::new(SumBackwardHelper {
            input: output.clone(),
        });
        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf)?;
        loss.backward()?;

        let analytic_grad = input.grad().unwrap().unwrap();
        let analytic = analytic_grad.data()?.to_vec();

        // Numerical gradient.
        for i in 0..hidden {
            let mut data_plus = input_data.clone();
            data_plus[i] += h;
            let inp_plus = leaf(&data_plus, &[1, hidden], false);
            let out_plus = no_grad(|| ln.forward(&inp_plus)).unwrap();
            let sum_plus: f64 = out_plus.data().unwrap().iter().sum();

            let mut data_minus = input_data.clone();
            data_minus[i] -= h;
            let inp_minus = leaf(&data_minus, &[1, hidden], false);
            let out_minus = no_grad(|| ln.forward(&inp_minus)).unwrap();
            let sum_minus: f64 = out_minus.data().unwrap().iter().sum();

            let numerical = (sum_plus - sum_minus) / (2.0 * h);
            assert!(
                (numerical - analytic[i]).abs() < 1e-4,
                "LayerNorm grad[{i}]: numerical={numerical}, analytic={}",
                analytic[i]
            );
        }

        Ok(())
    }
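The weakness of the plain sum loss in the gradient check above is easy to demonstrate. In the `f64` sketch below, the central-difference gradient of `sum(LN(x))` vanishes for the same input. A weighted loss `sum(r * LN(x))` with a non-constant `r` gives clearly nonzero gradients, so it would actually test the backward pass. `layer_norm_affine` and `numeric_grad` are hypothetical helpers, and `r` is arbitrary illustrative data.

```rust
/// LayerNorm over one row with per-element affine `w`, `b`.
fn layer_norm_affine(xs: &[f64], w: &[f64], b: &[f64], eps: f64) -> Vec<f64> {
    let n = xs.len() as f64;
    let mean = xs.iter().sum::<f64>() / n;
    let var = xs.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    let inv_std = 1.0 / (var + eps).sqrt();
    xs.iter()
        .zip(w.iter().zip(b))
        .map(|(x, (wi, bi))| (x - mean) * inv_std * wi + bi)
        .collect()
}

/// Central-difference gradient of a scalar function `f` at `x`.
fn numeric_grad<F: Fn(&[f64]) -> f64>(x: &[f64], f: F, h: f64) -> Vec<f64> {
    (0..x.len())
        .map(|i| {
            let (mut xp, mut xm) = (x.to_vec(), x.to_vec());
            xp[i] += h;
            xm[i] -= h;
            (f(&xp) - f(&xm)) / (2.0 * h)
        })
        .collect()
}
```

Driving the backward pass with a weighted upstream gradient such as `r` (instead of all ones) is what makes an analytic-versus-numerical comparison meaningful.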

    #[test]
    fn test_layer_norm_named_parameters() {
        let ln = LayerNorm::<f32>::new(vec![16], 1e-5, true).unwrap();
        let named = ln.named_parameters();
        assert_eq!(named.len(), 2);
        assert_eq!(named[0].0, "weight");
        assert_eq!(named[1].0, "bias");
    }

    #[test]
    fn test_layer_norm_train_eval() {
        let mut ln = LayerNorm::<f32>::new(vec![8], 1e-5, true).unwrap();
        assert!(ln.is_training());
        ln.eval();
        assert!(!ln.is_training());
        ln.train();
        assert!(ln.is_training());
    }

    // -----------------------------------------------------------------------
    // GroupNorm tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_group_norm_parameter_shapes() {
        let gn = GroupNorm::<f32>::new(4, 8, 1e-5, true).unwrap();
        let params = gn.parameters();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].shape(), &[8]); // weight
        assert_eq!(params[1].shape(), &[8]); // bias
    }

    #[test]
    fn test_group_norm_no_affine_no_params() {
        let gn = GroupNorm::<f32>::new(2, 4, 1e-5, false).unwrap();
        assert_eq!(gn.parameters().len(), 0);
    }

    #[test]
    fn test_group_norm_invalid_groups() {
        assert!(GroupNorm::<f32>::new(0, 8, 1e-5, true).is_err());
        assert!(GroupNorm::<f32>::new(3, 8, 1e-5, true).is_err()); // 8 not divisible by 3
    }

    #[test]
    fn test_group_norm_forward_zero_mean_unit_var() {
        // With groups=2, channels=4: groups are [0,1] and [2,3].
        // Each group should be normalized to ~zero mean, ~unit var.
        let data: Vec<f32> = vec![
            // batch=0, channel 0: 1, 2; channel 1: 3, 4; channel 2: 5, 6; channel 3: 7, 8
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
        ];
        // Shape: [1, 4, 2] (B=1, C=4, spatial=2)
        let input = Tensor::from_storage(TensorStorage::cpu(data), vec![1, 4, 2], false).unwrap();

        let gn = GroupNorm::<f32>::new(2, 4, 1e-5, true).unwrap();
        let output = gn.forward(&input).unwrap();
        let out_data = output.data().unwrap();

        // Group 0: channels 0,1 -> indices [0,1,2,3] -> values were [1,2,3,4]
        let group0: Vec<f32> = out_data[0..4].to_vec();
        let mean0: f32 = group0.iter().sum::<f32>() / 4.0;
        let var0: f32 = group0.iter().map(|&x| (x - mean0).powi(2)).sum::<f32>() / 4.0;
        assert!(mean0.abs() < 1e-5, "group0 mean = {mean0}");
        assert!((var0 - 1.0).abs() < 0.05, "group0 var = {var0}");

        // Group 1: channels 2,3 -> indices [4,5,6,7] -> values were [5,6,7,8]
        let group1: Vec<f32> = out_data[4..8].to_vec();
        let mean1: f32 = group1.iter().sum::<f32>() / 4.0;
        let var1: f32 = group1.iter().map(|&x| (x - mean1).powi(2)).sum::<f32>() / 4.0;
        assert!(mean1.abs() < 1e-5, "group1 mean = {mean1}");
        assert!((var1 - 1.0).abs() < 0.05, "group1 var = {var1}");
    }
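
    // Test-only reference for GroupNorm (no affine) on a contiguous [B, C, S]
    // buffer, S being the flattened spatial size. For this layout the
    // (channels / groups) channels of one group are adjacent in memory, so
    // every group is one contiguous chunk. A minimal sketch, assuming biased
    // variance as in torch.nn.GroupNorm; `reference_group_norm` is a
    // hypothetical helper, not a ferrotorch API.
    #[allow(dead_code)]
    fn reference_group_norm(
        x: &[f64],
        batch: usize,
        channels: usize,
        spatial: usize,
        groups: usize,
        eps: f64,
    ) -> Vec<f64> {
        assert_eq!(x.len(), batch * channels * spatial);
        assert!(groups > 0 && channels % groups == 0, "channels must divide into groups");
        let group_len = (channels / groups) * spatial;
        assert!(group_len > 0, "each group must contain at least one element");

        let mut out = vec![0.0; x.len()];
        // Chunks walk batch-major: (b=0, g=0), (b=0, g=1), ..., (b=1, g=0), ...
        for (src, dst) in x.chunks(group_len).zip(out.chunks_mut(group_len)) {
            let n = group_len as f64;
            let mean = src.iter().sum::<f64>() / n;
            let var = src.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / n;
            let inv_std = 1.0 / (var + eps).sqrt();
            for (d, &s) in dst.iter_mut().zip(src) {
                *d = (s - mean) * inv_std;
            }
        }
        out
    }

    #[test]
    fn test_reference_group_norm_matches_hand_example() {
        // Same layout as test_group_norm_forward_zero_mean_unit_var: [1, 4, 2], groups=2.
        let y = reference_group_norm(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], 1, 4, 2, 2, 1e-5);
        // Groups [1, 2, 3, 4] and [5, 6, 7, 8] differ only by a shift, so they
        // normalize to the same values.
        for (a, b) in y[..4].iter().zip(&y[4..]) {
            assert!((a - b).abs() < 1e-12);
        }
        assert!(y[..4].iter().sum::<f64>().abs() < 1e-12);
    }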

    #[test]
    fn test_group_norm_forward_shape_preserved() {
        let input =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 48]), vec![2, 4, 6], false)
                .unwrap();

        let gn = GroupNorm::<f32>::new(2, 4, 1e-5, true).unwrap();
        let output = gn.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 4, 6]);
    }

    #[test]
    fn test_group_norm_channel_mismatch() {
        let input =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 24]), vec![2, 3, 4], false)
                .unwrap();

        let gn = GroupNorm::<f32>::new(2, 4, 1e-5, true).unwrap();
        assert!(gn.forward(&input).is_err());
    }

    #[test]
    fn test_group_norm_has_grad_fn() {
        let input =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 8]), vec![1, 4, 2], true)
                .unwrap();

        let gn = GroupNorm::<f32>::new(2, 4, 1e-5, true).unwrap();
        let output = gn.forward(&input).unwrap();
        assert!(output.grad_fn().is_some());
        assert_eq!(output.grad_fn().unwrap().name(), "GroupNormBackward");
    }

    #[test]
    fn test_group_norm_backward_gradient_check() -> FerrotorchResult<()> {
        let h = 1e-7;
        // Shape: [1, 4, 2], groups=2
        let input_data = vec![1.0f64, -0.5, 2.0, 0.3, -1.0, 0.7, 1.5, -0.2];
        let gn = GroupNorm::<f64>::new(2, 4, 1e-5, true)?;

        let input = leaf(&input_data, &[1, 4, 2], true);
        let output = gn.forward(&input)?;
        let out_data = output.data()?.to_vec();
        let total: f64 = out_data.iter().sum();

        // Hand-built `sum` loss. With the default weight = 1 and bias = 0,
        // each group of the output sums to zero, so the expected gradient is ~0.
        let sum_gf = Arc::new(SumBackwardHelper {
            input: output.clone(),
        });
        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf)?;
        loss.backward()?;

        let analytic_grad = input.grad().unwrap().unwrap();
        let analytic = analytic_grad.data()?.to_vec();

        for i in 0..8 {
            let mut data_plus = input_data.clone();
            data_plus[i] += h;
            let inp_plus = leaf(&data_plus, &[1, 4, 2], false);
            let out_plus = no_grad(|| gn.forward(&inp_plus)).unwrap();
            let sum_plus: f64 = out_plus.data().unwrap().iter().sum();

            let mut data_minus = input_data.clone();
            data_minus[i] -= h;
            let inp_minus = leaf(&data_minus, &[1, 4, 2], false);
            let out_minus = no_grad(|| gn.forward(&inp_minus)).unwrap();
            let sum_minus: f64 = out_minus.data().unwrap().iter().sum();

            let numerical = (sum_plus - sum_minus) / (2.0 * h);
            assert!(
                (numerical - analytic[i]).abs() < 1e-4,
                "GroupNorm grad[{i}]: numerical={numerical}, analytic={}",
                analytic[i]
            );
        }

        Ok(())
    }

    #[test]
    fn test_group_norm_named_parameters() {
        let gn = GroupNorm::<f32>::new(2, 8, 1e-5, true).unwrap();
        let named = gn.named_parameters();
        assert_eq!(named.len(), 2);
        assert_eq!(named[0].0, "weight");
        assert_eq!(named[1].0, "bias");
    }

    // -----------------------------------------------------------------------
    // RMSNorm tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_rms_norm_parameter_shapes() {
        let rn = RMSNorm::<f32>::new(vec![8], 1e-5).unwrap();
        let params = rn.parameters();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].shape(), &[8]); // weight only
    }

    #[test]
    fn test_rms_norm_forward_scale() {
        // After RMSNorm (with weight=1), the RMS of each row should be ~1.
        let data: Vec<f32> = vec![
            1.0, 2.0, 3.0, 4.0, // row 0
            -1.0, 0.5, 2.0, -3.0, // row 1
        ];
        let input = Tensor::from_storage(TensorStorage::cpu(data), vec![2, 4], false).unwrap();

        let rn = RMSNorm::<f32>::new(vec![4], 1e-5).unwrap();
        let output = rn.forward(&input).unwrap();
        let out_data = output.data().unwrap();

        for row in 0..2 {
            let start = row * 4;
            let end = start + 4;
            let row_data = &out_data[start..end];

            let mean_sq: f32 = row_data.iter().map(|x| x * x).sum::<f32>() / 4.0;
            let rms = mean_sq.sqrt();

            assert!(
                (rms - 1.0).abs() < 0.05,
                "row {row} RMS = {rms}, expected ~1"
            );
        }
    }
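
    // Test-only reference for RMSNorm over one row:
    //     y_i = x_i / sqrt(mean(x^2) + eps) * weight_i
    // No mean subtraction and no bias, as described for REQ-4 in the module
    // docs. Placing `eps` inside the square root is an assumption that
    // mirrors torch.nn.RMSNorm; `reference_rms_norm_row` is a hypothetical
    // helper, not part of the ferrotorch API.
    #[allow(dead_code)]
    fn reference_rms_norm_row(x: &[f64], weight: &[f64], eps: f64) -> Vec<f64> {
        let mean_sq = x.iter().map(|&v| v * v).sum::<f64>() / x.len() as f64;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        x.iter().zip(weight).map(|(&v, &w)| v * inv_rms * w).collect()
    }

    #[test]
    fn test_reference_rms_norm_row_has_unit_rms() {
        // Row 0 of test_rms_norm_forward_scale, in f64.
        let y = reference_rms_norm_row(&[1.0, 2.0, 3.0, 4.0], &[1.0; 4], 1e-5);
        let rms = (y.iter().map(|v| v * v).sum::<f64>() / 4.0).sqrt();
        assert!((rms - 1.0).abs() < 1e-5);
    }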

    #[test]
    fn test_rms_norm_forward_shape_preserved() {
        let input =
            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0; 24]), vec![2, 3, 4], false)
                .unwrap();

        let rn = RMSNorm::<f32>::new(vec![4], 1e-5).unwrap();
        let output = rn.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 3, 4]);
    }

    #[test]
    fn test_rms_norm_empty_normalized_shape() {
        assert!(RMSNorm::<f32>::new(vec![], 1e-5).is_err());
    }

    #[test]
    fn test_rms_norm_has_grad_fn() {
        let input = Tensor::<f32>::from_storage(
            TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0]),
            vec![1, 4],
            true,
        )
        .unwrap();

        let rn = RMSNorm::<f32>::new(vec![4], 1e-5).unwrap();
        let output = rn.forward(&input).unwrap();
        assert!(output.grad_fn().is_some());
        assert_eq!(output.grad_fn().unwrap().name(), "RMSNormBackward");
    }

    #[test]
    fn test_rms_norm_no_grad_fn_in_no_grad_context() {
        let input = Tensor::<f32>::from_storage(
            TensorStorage::cpu(vec![1.0, 2.0, 3.0, 4.0]),
            vec![1, 4],
            true,
        )
        .unwrap();

        let rn = RMSNorm::<f32>::new(vec![4], 1e-5).unwrap();
        let output = no_grad(|| rn.forward(&input)).unwrap();
        assert!(output.grad_fn().is_none());
    }

    #[test]
    fn test_rms_norm_backward_gradient_check() -> FerrotorchResult<()> {
        let h = 1e-7;
        let hidden = 4;
        let input_data = vec![1.0f64, -0.5, 2.0, 0.3];

        let rn = RMSNorm::<f64>::new(vec![hidden], 1e-5)?;

        let input = leaf(&input_data, &[1, hidden], true);
        let output = rn.forward(&input)?;
        let out_data = output.data()?.to_vec();
        let total: f64 = out_data.iter().sum();

        let sum_gf = Arc::new(SumBackwardHelper {
            input: output.clone(),
        });
        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf)?;
        loss.backward()?;

        let analytic_grad = input.grad().unwrap().unwrap();
        let analytic = analytic_grad.data()?.to_vec();

        for i in 0..hidden {
            let mut data_plus = input_data.clone();
            data_plus[i] += h;
            let inp_plus = leaf(&data_plus, &[1, hidden], false);
            let out_plus = no_grad(|| rn.forward(&inp_plus)).unwrap();
            let sum_plus: f64 = out_plus.data().unwrap().iter().sum();

            let mut data_minus = input_data.clone();
            data_minus[i] -= h;
            let inp_minus = leaf(&data_minus, &[1, hidden], false);
            let out_minus = no_grad(|| rn.forward(&inp_minus)).unwrap();
            let sum_minus: f64 = out_minus.data().unwrap().iter().sum();

            let numerical = (sum_plus - sum_minus) / (2.0 * h);
            assert!(
                (numerical - analytic[i]).abs() < 1e-4,
                "RMSNorm grad[{i}]: numerical={numerical}, analytic={}",
                analytic[i]
            );
        }

        Ok(())
    }

    #[test]
    fn test_rms_norm_named_parameters() {
        let rn = RMSNorm::<f32>::new(vec![16], 1e-5).unwrap();
        let named = rn.named_parameters();
        assert_eq!(named.len(), 1);
        assert_eq!(named[0].0, "weight");
    }

    #[test]
    fn test_rms_norm_train_eval() {
        let mut rn = RMSNorm::<f32>::new(vec![8], 1e-5).unwrap();
        assert!(rn.is_training());
        rn.eval();
        assert!(!rn.is_training());
        rn.train();
        assert!(rn.is_training());
    }

    // -----------------------------------------------------------------------
    // BatchNorm2d tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_batch_norm_2d_output_shape() {
        let bn = BatchNorm2d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        // [B=2, C=3, H=4, W=4]
        let input = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32; 2 * 3 * 4 * 4]),
            vec![2, 3, 4, 4],
            false,
        )
        .unwrap();

        let output = bn.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 3, 4, 4]);
    }

    #[test]
    fn test_batch_norm_2d_rejects_non_4d() {
        let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
        // 3D input should fail.
        let input =
            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 24]), vec![2, 4, 3], false)
                .unwrap();
        assert!(bn.forward(&input).is_err());
    }

    #[test]
    fn test_batch_norm_2d_channel_mismatch() {
        let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
        let input = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32; 2 * 3 * 2 * 2]),
            vec![2, 3, 2, 2],
            false,
        )
        .unwrap();
        assert!(bn.forward(&input).is_err());
    }

    #[test]
    fn test_batch_norm_2d_zero_features() {
        assert!(BatchNorm2d::<f32>::new(0, 1e-5, 0.1, true).is_err());
    }

    #[test]
    fn test_batch_norm_2d_training_normalizes() {
        // After training-mode BatchNorm2d (weight=1, bias=0), each channel
        // should have approximately zero mean and unit variance over (B, H, W).
        let channels = 2;
        let b = 2;
        let h = 3;
        let w = 3;
        let spatial = h * w;
        // Build data: channel 0 has values 1..18, channel 1 has 101..118
        let mut data = Vec::new();
        for bi in 0..b {
            for c in 0..channels {
                let offset = c as f32 * 100.0;
                for s in 0..spatial {
                    data.push(offset + (bi * spatial + s) as f32 + 1.0);
                }
            }
        }
        let input =
            Tensor::from_storage(TensorStorage::cpu(data), vec![b, channels, h, w], false).unwrap();

        let bn = BatchNorm2d::<f32>::new(channels, 1e-5, 0.1, true).unwrap();
        let output = bn.forward(&input).unwrap();
        let out_data = output.data().unwrap();

        for c in 0..channels {
            // Gather all values for this channel across (B, H, W).
            let mut vals = Vec::new();
            for bi in 0..b {
                let base = bi * channels * spatial + c * spatial;
                for s in 0..spatial {
                    vals.push(out_data[base + s]);
                }
            }
            let n = vals.len() as f32;
            let mean: f32 = vals.iter().sum::<f32>() / n;
            let var: f32 = vals.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n;

            assert!(mean.abs() < 1e-4, "channel {c}: mean = {mean}, expected ~0");
            assert!(
                (var - 1.0).abs() < 0.1,
                "channel {c}: var = {var}, expected ~1"
            );
        }
    }

    #[test]
    fn test_batch_norm_2d_eval_uses_running_stats() {
        let channels = 2;
        let b = 4;
        let h = 2;
        let w = 2;
        let spatial = h * w;

        // Create the layer; one training batch below populates the running stats.
        let bn = BatchNorm2d::<f64>::new(channels, 1e-5, 0.1, true).unwrap();

        // Training batch with known data.
        let mut data = vec![0.0f64; b * channels * spatial];
        for bi in 0..b {
            for c in 0..channels {
                let base = bi * channels * spatial + c * spatial;
                for s in 0..spatial {
                    data[base + s] = (c as f64) * 10.0 + (bi * spatial + s) as f64;
                }
            }
        }
        let input = Tensor::from_storage(
            TensorStorage::cpu(data.clone()),
            vec![b, channels, h, w],
            false,
        )
        .unwrap();

        // Training forward to update running stats.
        let _ = bn.forward(&input).unwrap();
        let rm_after_train = bn.running_mean();
        let rv_after_train = bn.running_var();

        // The running mean should have moved away from its initial [0, 0].
        assert!(
            rm_after_train[0].abs() > 1e-6 || rm_after_train[1].abs() > 1e-6,
            "running_mean should have been updated"
        );

        // Switch to eval mode and forward again. `bn` is not bound `mut`, so
        // clear the flag through its `training` mutex instead of calling `eval()`.
        *bn.training.lock().unwrap() = false;

        let output_eval = bn.forward(&input).unwrap();
        let eval_data = output_eval.data().unwrap();

        // In eval mode the output uses running_mean/running_var rather than
        // the batch statistics, so per-channel values need not have zero mean.
        // Check every element against a manual computation from the running stats.
        for c in 0..channels {
            let expected_mean = rm_after_train[c];
            let expected_var = rv_after_train[c];
            let inv_std = 1.0 / (expected_var + 1e-5).sqrt();

            for bi in 0..b {
                let base = bi * channels * spatial + c * spatial;
                for s in 0..spatial {
                    let x = (c as f64) * 10.0 + (bi * spatial + s) as f64;
                    let expected = (x - expected_mean) * inv_std;
                    // weight=1, bias=0 by default.
                    let actual = eval_data[base + s];
                    assert!(
                        (actual - expected).abs() < 1e-6,
                        "eval output mismatch at b={bi}, c={c}, s={s}: actual={actual}, expected={expected}"
                    );
                }
            }
        }
    }
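
    // Test-only reference for the eval-mode formula the test above checks by
    // hand (there with weight = 1, bias = 0):
    //     y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
    // `reference_batch_norm_eval` is a hypothetical scalar helper for one
    // channel value, not part of the ferrotorch API.
    #[allow(dead_code)]
    fn reference_batch_norm_eval(
        x: f64,
        running_mean: f64,
        running_var: f64,
        weight: f64,
        bias: f64,
        eps: f64,
    ) -> f64 {
        (x - running_mean) / (running_var + eps).sqrt() * weight + bias
    }

    #[test]
    fn test_reference_batch_norm_eval_formula() {
        // Identity statistics and affine parameters leave x unchanged.
        assert!((reference_batch_norm_eval(3.0, 0.0, 1.0, 1.0, 0.0, 0.0) - 3.0).abs() < 1e-12);
        // (3 - 1) / sqrt(4) * 2 + 0.5 = 2.5
        assert!((reference_batch_norm_eval(3.0, 1.0, 4.0, 2.0, 0.5, 0.0) - 2.5).abs() < 1e-12);
    }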

    #[test]
    fn test_batch_norm_2d_running_stats_update() {
        let channels = 2;
        let bn = BatchNorm2d::<f32>::new(channels, 1e-5, 0.1, true).unwrap();

        // Initial state.
        assert_eq!(bn.running_mean(), vec![0.0, 0.0]);
        assert_eq!(bn.running_var(), vec![1.0, 1.0]);
        assert_eq!(bn.num_batches_tracked(), 0);

        // Forward pass 1.
        let input = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32; 2 * 2 * 2 * 2]),
            vec![2, 2, 2, 2],
            false,
        )
        .unwrap();
        let _ = bn.forward(&input).unwrap();
        assert_eq!(bn.num_batches_tracked(), 1);

        let rm = bn.running_mean();
        let rv = bn.running_var();
        // running_mean = (1-0.1)*0 + 0.1*batch_mean = 0.1*1.0 = 0.1
        assert!(
            (rm[0] - 0.1).abs() < 1e-5,
            "running_mean[0] = {}, expected 0.1",
            rm[0]
        );
        // batch_var = 0 (all values are 1.0), so the Bessel-corrected var is also 0.
        // running_var = (1-0.1)*1.0 + 0.1*0.0 = 0.9
        assert!(
            (rv[0] - 0.9).abs() < 1e-5,
            "running_var[0] = {}, expected 0.9",
            rv[0]
        );

        // Forward pass 2.
        let _ = bn.forward(&input).unwrap();
        assert_eq!(bn.num_batches_tracked(), 2);
    }
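
    // Test-only reference for the running-statistics update checked above,
    // assuming PyTorch's convention: the batch mean and the unbiased
    // (Bessel-corrected, divide by n - 1) batch variance are blended in with
    //     running = (1 - momentum) * running + momentum * batch_stat
    // `reference_running_stats_update` is a hypothetical helper operating on
    // the values of one channel gathered across (B, H, W).
    #[allow(dead_code)]
    fn reference_running_stats_update(
        values: &[f64],
        running_mean: f64,
        running_var: f64,
        momentum: f64,
    ) -> (f64, f64) {
        assert!(values.len() > 1, "unbiased variance needs at least two values");
        let n = values.len() as f64;
        let mean = values.iter().sum::<f64>() / n;
        let unbiased_var = values.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / (n - 1.0);
        (
            (1.0 - momentum) * running_mean + momentum * mean,
            (1.0 - momentum) * running_var + momentum * unbiased_var,
        )
    }

    #[test]
    fn test_reference_running_stats_update_constant_batch() {
        // Same numbers as test_batch_norm_2d_running_stats_update:
        // 2 * 2 * 2 = 8 ones per channel, starting from mean 0, var 1.
        let (rm, rv) = reference_running_stats_update(&[1.0; 8], 0.0, 1.0, 0.1);
        assert!((rm - 0.1).abs() < 1e-12);
        assert!((rv - 0.9).abs() < 1e-12);
    }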

    #[test]
    fn test_batch_norm_2d_affine_parameters() {
        let bn = BatchNorm2d::<f32>::new(8, 1e-5, 0.1, true).unwrap();
        let params = bn.parameters();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].shape(), &[8]); // weight
        assert_eq!(params[1].shape(), &[8]); // bias

        let named = bn.named_parameters();
        assert_eq!(named.len(), 2);
        assert_eq!(named[0].0, "weight");
        assert_eq!(named[1].0, "bias");

        // Weight should be ones, bias should be zeros.
        let weight_data = params[0].data().unwrap();
        let bias_data = params[1].data().unwrap();
        assert!(weight_data.iter().all(|&x| (x - 1.0).abs() < 1e-7));
        assert!(bias_data.iter().all(|&x| x.abs() < 1e-7));
    }

    #[test]
    fn test_batch_norm_2d_no_affine_no_params() {
        let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, false).unwrap();
        assert_eq!(bn.parameters().len(), 0);
        assert_eq!(bn.named_parameters().len(), 0);
        assert!(bn.weight.is_none());
        assert!(bn.bias.is_none());
    }

    #[test]
    fn test_batch_norm_2d_train_eval_toggle() {
        let mut bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
        assert!(bn.is_training());
        bn.eval();
        assert!(!bn.is_training());
        bn.train();
        assert!(bn.is_training());
    }

    #[test]
    fn test_batch_norm_2d_has_grad_fn() {
        let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        let input = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32; 2 * 2 * 3 * 3]),
            vec![2, 2, 3, 3],
            true,
        )
        .unwrap();

        let output = bn.forward(&input).unwrap();
        assert!(output.grad_fn().is_some());
        assert_eq!(output.grad_fn().unwrap().name(), "BatchNorm2dBackward");
    }

    #[test]
    fn test_batch_norm_2d_no_grad_fn_in_no_grad_context() {
        let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        let input = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32; 2 * 2 * 3 * 3]),
            vec![2, 2, 3, 3],
            true,
        )
        .unwrap();

        let output = no_grad(|| bn.forward(&input)).unwrap();
        assert!(output.grad_fn().is_none());
    }

    #[test]
    fn test_batch_norm_2d_backward_gradient_check() -> FerrotorchResult<()> {
        let h_eps = 1e-7;
        let channels = 2;
        let b = 2;
        let height = 2;
        let width = 2;
        let spatial = height * width;
        let numel = b * channels * spatial;

        // Build non-trivial input data.
        let input_data: Vec<f64> = (0..numel).map(|i| (i as f64) * 0.3 - 1.0).collect();

        let bn = BatchNorm2d::<f64>::new(channels, 1e-5, 0.1, true)?;

        let input = leaf(&input_data, &[b, channels, height, width], true);
        let output = bn.forward(&input)?;
        let out_data = output.data()?.to_vec();
        let total: f64 = out_data.iter().sum();

        // Hand-built `sum` loss. With the default weight = 1 and bias = 0,
        // each channel of the normalized output sums to zero, so the expected
        // gradient is ~0.
        let sum_gf = Arc::new(SumBackwardHelper {
            input: output.clone(),
        });
        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf)?;
        loss.backward()?;

        let analytic_grad = input.grad().unwrap().unwrap();
        let analytic = analytic_grad.data()?.to_vec();

        // Numerical gradient. Each evaluation uses a fresh training-mode
        // BatchNorm2d, so, like the analytic pass, it normalizes with the batch
        // statistics of its own input, and running-stat updates from one
        // evaluation never leak into another.
        for i in 0..numel {
            // f(x + h)
            let mut data_plus = input_data.clone();
            data_plus[i] += h_eps;
            let inp_plus = leaf(&data_plus, &[b, channels, height, width], false);
            let bn_plus = BatchNorm2d::<f64>::new(channels, 1e-5, 0.1, true)?;
            let out_plus = no_grad(|| bn_plus.forward(&inp_plus)).unwrap();
            let sum_plus: f64 = out_plus.data().unwrap().iter().sum();

            // f(x - h)
            let mut data_minus = input_data.clone();
            data_minus[i] -= h_eps;
            let inp_minus = leaf(&data_minus, &[b, channels, height, width], false);
            let bn_minus = BatchNorm2d::<f64>::new(channels, 1e-5, 0.1, true)?;
            let out_minus = no_grad(|| bn_minus.forward(&inp_minus)).unwrap();
            let sum_minus: f64 = out_minus.data().unwrap().iter().sum();

            let numerical = (sum_plus - sum_minus) / (2.0 * h_eps);
            assert!(
                (numerical - analytic[i]).abs() < 1e-4,
                "BatchNorm2d grad[{i}]: numerical={numerical}, analytic={}",
                analytic[i]
            );
        }

        Ok(())
    }
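
    // The gradient checks in this module all repeat the same central-difference
    // pattern:  df/dx_i ~= (f(x + h e_i) - f(x - h e_i)) / (2h).
    // `central_difference_grad` is a hypothetical generic helper over a
    // scalar-valued closure; the existing tests do not use it, and the step
    // size is left to the caller (the GroupNorm, RMSNorm and BatchNorm2d
    // checks above use 1e-7 in f64).
    #[allow(dead_code)]
    fn central_difference_grad<F>(f: F, x: &[f64], h: f64) -> Vec<f64>
    where
        F: Fn(&[f64]) -> f64,
    {
        let mut probe = x.to_vec();
        (0..x.len())
            .map(|i| {
                probe[i] = x[i] + h;
                let f_plus = f(&probe[..]);
                probe[i] = x[i] - h;
                let f_minus = f(&probe[..]);
                // Restore the coordinate before moving on to the next one.
                probe[i] = x[i];
                (f_plus - f_minus) / (2.0 * h)
            })
            .collect()
    }

    #[test]
    fn test_central_difference_grad_sum_of_squares() {
        // d/dx_i sum(x^2) = 2 * x_i
        let x = [1.0, -2.0, 0.5];
        let g = central_difference_grad(|v: &[f64]| v.iter().map(|a| a * a).sum::<f64>(), &x, 1e-5);
        for (gi, xi) in g.iter().zip(x.iter()) {
            assert!((gi - 2.0 * xi).abs() < 1e-6);
        }
    }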

    #[test]
    fn test_batch_norm_2d_no_affine_forward() {
        // Verify that non-affine mode still normalizes correctly.
        let channels = 2;
        let b = 2;
        let h = 2;
        let w = 2;
        let spatial = h * w;

        let mut data = Vec::new();
        for bi in 0..b {
            for c in 0..channels {
                for s in 0..spatial {
                    data.push((c as f32) * 5.0 + (bi * spatial + s) as f32);
                }
            }
        }

        let input =
            Tensor::from_storage(TensorStorage::cpu(data), vec![b, channels, h, w], false).unwrap();

        let bn = BatchNorm2d::<f32>::new(channels, 1e-5, 0.1, false).unwrap();
        let output = bn.forward(&input).unwrap();
        let out_data = output.data().unwrap();

        for c in 0..channels {
            let mut vals = Vec::new();
            for bi in 0..b {
                let base = bi * channels * spatial + c * spatial;
                for s in 0..spatial {
                    vals.push(out_data[base + s]);
                }
            }
            let n = vals.len() as f32;
            let mean: f32 = vals.iter().sum::<f32>() / n;
            let var: f32 = vals.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / n;

            assert!(mean.abs() < 1e-4, "no-affine channel {c}: mean = {mean}");
            assert!(
                (var - 1.0).abs() < 0.1,
                "no-affine channel {c}: var = {var}"
            );
        }
    }

    // -----------------------------------------------------------------------
    // Send + Sync tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_layer_norm_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<LayerNorm<f32>>();
    }

    #[test]
    fn test_group_norm_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<GroupNorm<f32>>();
    }

    #[test]
    fn test_rms_norm_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<RMSNorm<f32>>();
    }

    #[test]
    fn test_batch_norm_2d_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<BatchNorm2d<f32>>();
    }

    // -----------------------------------------------------------------------
    // Helper backward node for tests
    // -----------------------------------------------------------------------

    /// Test-only shorthand for `T::one()`, called through `num_traits::One`
    /// so the call is unambiguous.
    fn one<T: Float>() -> T {
        <T as num_traits::One>::one()
    }

    /// Test-only backward node for `loss = sum(input)`: every input element
    /// receives gradient 1. It ignores `_grad_output`, which is only correct
    /// when this node is the root of the graph (upstream gradient 1), as in
    /// the gradient checks above.
    #[derive(Debug)]
    struct SumBackwardHelper<T: Float> {
        input: Tensor<T>,
    }

    impl<T: Float> GradFn<T> for SumBackwardHelper<T> {
        fn backward(&self, _grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
            let ones_data = vec![one::<T>(); self.input.numel()];
            let ones = Tensor::from_storage(
                TensorStorage::cpu(ones_data),
                self.input.shape().to_vec(),
                false,
            )?;
            Ok(vec![Some(ones)])
        }

        fn inputs(&self) -> Vec<&Tensor<T>> {
            vec![&self.input]
        }

        fn name(&self) -> &'static str {
            "SumBackwardHelper"
        }
    }

    // -----------------------------------------------------------------------
    // BatchNorm1d tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_batchnorm1d_parameter_shapes() {
        let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let params = bn.parameters();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].shape(), &[3]); // weight
        assert_eq!(params[1].shape(), &[3]); // bias
    }

    #[test]
    fn test_batchnorm1d_no_affine() {
        let bn = BatchNorm1d::<f32>::new(4, 1e-5, 0.1, false).unwrap();
        assert!(bn.parameters().is_empty());
    }

    #[test]
    fn test_batchnorm1d_2d_input() {
        // Input: [N=4, C=2]
        let bn = BatchNorm1d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        let input = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            vec![4, 2],
            false,
        )
        .unwrap();
        let output = bn.forward(&input).unwrap();
        assert_eq!(output.shape(), &[4, 2]);
    }

    #[test]
    fn test_batchnorm1d_3d_input() {
        // Input: [N=2, C=3, L=4]
        let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let input = Tensor::from_storage(TensorStorage::cpu(data), vec![2, 3, 4], false).unwrap();
        let output = bn.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 3, 4]);
    }

    #[test]
    fn test_batchnorm1d_wrong_dims() {
        // Only [N, C] and [N, C, L] inputs are accepted; 1D and 4D should fail.
        let bn = BatchNorm1d::<f32>::new(4, 1e-5, 0.1, true).unwrap();

        let input_1d = Tensor::from_storage(
            TensorStorage::cpu(vec![1.0f32, 2.0, 3.0, 4.0]),
            vec![4],
            false,
        )
        .unwrap();
        assert!(bn.forward(&input_1d).is_err());

        let input_4d = Tensor::from_storage(
            TensorStorage::cpu(vec![0.0f32; 32]),
            vec![2, 4, 2, 2],
            false,
        )
        .unwrap();
        assert!(bn.forward(&input_4d).is_err());
    }

    #[test]
    fn test_batchnorm1d_zero_features() {
        assert!(BatchNorm1d::<f32>::new(0, 1e-5, 0.1, true).is_err());
    }
5566
5567    #[test]
5568    fn test_batchnorm1d_zero_features() {
5569        assert!(BatchNorm1d::<f32>::new(0, 1e-5, 0.1, true).is_err());
5570    }
5571
5572    #[test]
5573    fn test_batchnorm1d_channel_mismatch() {
5574        let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
5575        let input =
5576            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 8]), vec![2, 4], false).unwrap();
5577        assert!(bn.forward(&input).is_err());
5578    }
5579
5580    #[test]
5581    fn test_batchnorm1d_training_normalizes() {
5582        // After training-mode BatchNorm1d (weight=1, bias=0), each channel
5583        // should have approximately zero mean and unit variance.
5584        let channels = 2;
5585        let bn = BatchNorm1d::<f64>::new(channels, 1e-5, 0.1, true).unwrap();
5586
5587        // Input [4, 2]: channel 0 = [1,3,5,7], channel 1 = [2,4,6,8]
5588        let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2], false);
5589        let output = bn.forward(&input).unwrap();
5590        let data = output.data().unwrap();
5591
5592        // Check channel 0 mean ~ 0
5593        let ch0: Vec<f64> = (0..4).map(|b| data[b * 2]).collect();
5594        let ch0_mean: f64 = ch0.iter().sum::<f64>() / 4.0;
5595        assert!(
5596            ch0_mean.abs() < 1e-5,
5597            "BatchNorm1d channel 0 mean should be ~0, got {}",
5598            ch0_mean
5599        );
5600
5601        // Check channel 0 variance ~ 1
5602        let ch0_var: f64 = ch0.iter().map(|&x| (x - ch0_mean).powi(2)).sum::<f64>() / 4.0;
5603        assert!(
5604            (ch0_var - 1.0).abs() < 0.1,
5605            "BatchNorm1d channel 0 var should be ~1, got {}",
5606            ch0_var
5607        );
5608    }
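
    // Sketch (not the ferrotorch kernel): training-mode BatchNorm1d on a
    // row-major [N, C] slice with weight = 1 and bias = 0. Each channel is
    // reduced over the batch axis with the biased variance (divide by N), the
    // convention the variance check above relies on. `sketch_batch_norm_nc`
    // is a hypothetical helper for illustration only.
    fn sketch_batch_norm_nc(x: &[f64], n: usize, c: usize, eps: f64) -> Vec<f64> {
        assert_eq!(x.len(), n * c, "input must hold N * C values");
        let mut out = vec![0.0; x.len()];
        for ch in 0..c {
            // Element (b, ch) lives at b * c + ch in row-major [N, C].
            let mean = (0..n).map(|b| x[b * c + ch]).sum::<f64>() / n as f64;
            let var = (0..n)
                .map(|b| (x[b * c + ch] - mean).powi(2))
                .sum::<f64>()
                / n as f64;
            let inv_std = 1.0 / (var + eps).sqrt();
            for b in 0..n {
                out[b * c + ch] = (x[b * c + ch] - mean) * inv_std;
            }
        }
        out
    }

    #[test]
    fn sketch_batch_norm_nc_zero_mean_unit_var() {
        let x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let y = sketch_batch_norm_nc(&x, 4, 2, 1e-5);
        // Channel 0 = [1, 3, 5, 7]: mean 4, biased variance 5.
        let ch0: Vec<f64> = (0..4).map(|b| y[b * 2]).collect();
        let mean = ch0.iter().sum::<f64>() / 4.0;
        let var = ch0.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / 4.0;
        assert!(mean.abs() < 1e-12);
        assert!((var - 1.0).abs() < 1e-5);
    }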
5609
5610    #[test]
5611    fn test_batchnorm1d_running_stats_update() {
5612        let bn = BatchNorm1d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
5613        let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2], false);
5614        let _ = bn.forward(&input).unwrap();
5615
5616        assert_eq!(bn.num_batches_tracked(), 1);
5617        let rm = bn.running_mean();
5618        let rv = bn.running_var();
5619        // Channel 0 mean = (1+3+5+7)/4 = 4.0
5620        // Channel 1 mean = (2+4+6+8)/4 = 5.0
5621        // running_mean = 0.9 * 0 + 0.1 * batch_mean
5622        assert!(
5623            (rm[0] - 0.1 * 4.0).abs() < 1e-7,
5624            "running_mean[0]: expected {}, got {}",
5625            0.1 * 4.0,
5626            rm[0]
5627        );
5628        assert!(
5629            (rm[1] - 0.1 * 5.0).abs() < 1e-7,
5630            "running_mean[1]: expected {}, got {}",
5631            0.1 * 5.0,
5632            rm[1]
5633        );
5634
5635        // running_var takes the Bessel-corrected (N - 1) batch variance; only positivity is checked.
5636        assert!(rv[0] > 0.0);
5637        assert!(rv[1] > 0.0);
5638    }
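
    // Sketch of the momentum update checked above, written against PyTorch's
    // documented BatchNorm convention (an assumption about ferrotorch's
    // internals): new = (1 - momentum) * old + momentum * batch_stat, where the
    // variance fed into the running estimate is the Bessel-corrected
    // (divide by N - 1) batch variance. The initial running_var of 1.0 used in
    // the test below is PyTorch's default and is likewise assumed. Both helpers
    // are hypothetical.
    fn sketch_momentum_update(running: f64, batch_stat: f64, momentum: f64) -> f64 {
        (1.0 - momentum) * running + momentum * batch_stat
    }

    fn sketch_unbiased_var(xs: &[f64]) -> f64 {
        let n = xs.len() as f64;
        let mean = xs.iter().sum::<f64>() / n;
        xs.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0)
    }

    #[test]
    fn sketch_running_stats_channel0() {
        let ch0 = [1.0, 3.0, 5.0, 7.0];
        // Batch mean 4.0; unbiased variance 20 / 3.
        let rm = sketch_momentum_update(0.0, 4.0, 0.1);
        let rv = sketch_momentum_update(1.0, sketch_unbiased_var(&ch0), 0.1);
        assert!((rm - 0.4).abs() < 1e-12);
        assert!((rv - (0.9 + 0.1 * 20.0 / 3.0)).abs() < 1e-12);
    }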
5639
5640    #[test]
5641    fn test_batchnorm1d_eval_mode() {
5642        let bn = BatchNorm1d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
5643
5644        // Run training forward to populate running stats.
5645        let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2], false);
5646        let _ = bn.forward(&input).unwrap();
5647
5648        // Switch to eval mode by flipping the training flag directly:
5649        // `eval()` takes `&mut self`, and `bn` is not declared `mut` here.
5650        *bn.training.lock().unwrap() = false;
5651
5652        let eval_out = bn.forward(&input).unwrap();
5653        assert_eq!(eval_out.shape(), &[4, 2]);
5654    }
5655
5656    #[test]
5657    fn test_batchnorm1d_no_affine_normalizes() {
5658        let bn = BatchNorm1d::<f64>::new(2, 1e-5, 0.1, false).unwrap();
5659        let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2], false);
5660        let output = bn.forward(&input).unwrap();
5661        assert_eq!(output.shape(), &[4, 2]);
5662    }
5663
5664    #[test]
5665    fn test_batchnorm1d_3d_normalizes() {
5666        // [N=2, C=2, L=3]
5667        let channels = 2;
5668        let bn = BatchNorm1d::<f64>::new(channels, 1e-5, 0.1, true).unwrap();
5669        let data: Vec<f64> = (0..12).map(|i| i as f64).collect();
5670        let input = leaf(&data, &[2, 2, 3], false);
5671        let output = bn.forward(&input).unwrap();
5672        assert_eq!(output.shape(), &[2, 2, 3]);
5673
5674        // Each channel normalized over (N, L) = 6 elements.
5675        let out_data = output.data().unwrap();
5676
5677        // Channel 0 indices in [N, C, L] layout:
5678        // [0, 0, :] = indices 0,1,2; [1, 0, :] = indices 6,7,8
5679        let ch0: Vec<f64> = vec![
5680            out_data[0],
5681            out_data[1],
5682            out_data[2],
5683            out_data[6],
5684            out_data[7],
5685            out_data[8],
5686        ];
5687        let ch0_mean: f64 = ch0.iter().sum::<f64>() / 6.0;
5688        assert!(
5689            ch0_mean.abs() < 1e-5,
5690            "BatchNorm1d 3D channel 0 mean should be ~0, got {}",
5691            ch0_mean
5692        );
5693    }
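
    // Sketch: the hand-picked indices above (0, 1, 2, 6, 7, 8) follow from
    // row-major [N, C, L] strides, where element (n, c, l) sits at
    // n * C * L + c * L + l. The hypothetical helper below gathers one
    // channel's N * L values in that order.
    fn sketch_channel_values(data: &[f64], shape: [usize; 3], channel: usize) -> Vec<f64> {
        let [n, c, l] = shape;
        assert!(channel < c, "channel out of range");
        assert_eq!(data.len(), n * c * l, "data must match shape");
        let mut out = Vec::with_capacity(n * l);
        for b in 0..n {
            // Each (b, channel) row is a contiguous run of L values.
            let base = b * c * l + channel * l;
            out.extend_from_slice(&data[base..base + l]);
        }
        out
    }

    #[test]
    fn sketch_channel_values_matches_manual_indices() {
        let data: Vec<f64> = (0..12).map(|i| i as f64).collect();
        // Channel 0 of [2, 2, 3]: flat indices 0, 1, 2, 6, 7, 8.
        assert_eq!(
            sketch_channel_values(&data, [2, 2, 3], 0),
            vec![0.0, 1.0, 2.0, 6.0, 7.0, 8.0]
        );
    }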
5694
5695    #[test]
5696    fn test_batchnorm1d_train_eval_toggle() {
5697        let bn = BatchNorm1d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
5698        assert!(bn.is_training());
5699        *bn.training.lock().unwrap() = false;
5700        assert!(!bn.is_training());
5701        *bn.training.lock().unwrap() = true;
5702        assert!(bn.is_training());
5703    }
5704
5705    #[test]
5706    fn test_batchnorm1d_grad_fn_name() {
5707        let bn = BatchNorm1d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
5708        let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2], true);
5709        let output = bn.forward(&input).unwrap();
5710        assert_eq!(output.grad_fn().unwrap().name(), "BatchNorm1dBackward");
5711    }
5712
5713    #[test]
5714    fn test_batchnorm1d_backward_grad_shapes() {
5715        use ferrotorch_core::autograd::graph::backward;
5716
5717        let bn = BatchNorm1d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
5718        let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[4, 2], true);
5719        let output = bn.forward(&input).unwrap();
5720
5721        // Create a differentiable sum via SumBackwardHelper.
5722        let out_data = output.data().unwrap().to_vec();
5723        let total: f64 = out_data.iter().sum();
5724        let sum_gf = Arc::new(SumBackwardHelper {
5725            input: output.clone(),
5726        });
5727        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
5728        backward(&loss).unwrap();
5729
5730        let grad = input.grad().unwrap().unwrap();
5731        assert_eq!(grad.shape(), &[4, 2]);
5732    }
5733
5734    // #1567: a non-affine BatchNorm backward used to return a length-3 grad
5735    // vec while `inputs()` (and the autograd graph) only registered `input`,
5736    // so the engine errored "backward returned 3 gradients but expected 1".
5737    // These CPU regressions prove affine=false backward now runs end-to-end
5738    // (the GPU helper has the identical fix; pinned by the CUDA-gated
5739    // `divergence_bn2d_gpu_nonaffine_train_backward_vs_torch`). The grad-vec
5740    // length must match `inputs()` per
5741    // `aten/src/ATen/native/Normalization.cpp:322-330` (grad_input_mask).
5742    #[test]
5743    fn test_batchnorm2d_nonaffine_backward_runs_grad_input_only() {
5744        use ferrotorch_core::autograd::graph::backward;
5745
5746        let bn = BatchNorm2d::<f64>::new(4, 1e-5, 0.1, false).unwrap();
5747        assert!(bn.weight.is_none() && bn.bias.is_none());
5748        let n = 3 * 4 * 2 * 3;
5749        let data: Vec<f64> = (0..n).map(|k| ((k % 19) as f64) * 0.11 - 1.0).collect();
5750        let input = leaf(&data, &[3, 4, 2, 3], true);
5751        let output = bn.forward(&input).unwrap();
5752
5753        // The grad node declares only `input`, so the engine must accept the
5754        // length-1 grad vec (would error before #1567).
5755        assert_eq!(output.grad_fn().unwrap().inputs().len(), 1);
5756
5757        let total: f64 = output.data().unwrap().iter().sum();
5758        let sum_gf = Arc::new(SumBackwardHelper {
5759            input: output.clone(),
5760        });
5761        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
5762        backward(&loss).unwrap();
5763
5764        let grad = input.grad().unwrap().expect("grad_input populated");
5765        assert_eq!(grad.shape(), &[3, 4, 2, 3]);
5766    }
5767
5768    #[test]
5769    fn test_batchnorm1d_nonaffine_backward_runs_grad_input_only() {
5770        use ferrotorch_core::autograd::graph::backward;
5771
5772        let bn = BatchNorm1d::<f64>::new(3, 1e-5, 0.1, false).unwrap();
5773        let n = 4 * 3 * 5;
5774        let data: Vec<f64> = (0..n).map(|k| ((k % 17) as f64) * 0.13 - 0.9).collect();
5775        let input = leaf(&data, &[4, 3, 5], true);
5776        let output = bn.forward(&input).unwrap();
5777        assert_eq!(output.grad_fn().unwrap().inputs().len(), 1);
5778
5779        let total: f64 = output.data().unwrap().iter().sum();
5780        let sum_gf = Arc::new(SumBackwardHelper {
5781            input: output.clone(),
5782        });
5783        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
5784        backward(&loss).unwrap();
5785
5786        let grad = input.grad().unwrap().expect("grad_input populated");
5787        assert_eq!(grad.shape(), &[4, 3, 5]);
5788    }
5789
5790    #[test]
5791    fn test_batchnorm3d_nonaffine_backward_runs_grad_input_only() {
5792        use ferrotorch_core::autograd::graph::backward;
5793
5794        let bn = BatchNorm3d::<f64>::new(3, 1e-5, 0.1, false).unwrap();
5795        let n = 2 * 3 * 2 * 2 * 2;
5796        let data: Vec<f64> = (0..n).map(|k| ((k % 13) as f64) * 0.17 - 1.0).collect();
5797        let input = leaf(&data, &[2, 3, 2, 2, 2], true);
5798        let output = bn.forward(&input).unwrap();
5799        assert_eq!(output.grad_fn().unwrap().inputs().len(), 1);
5800
5801        let total: f64 = output.data().unwrap().iter().sum();
5802        let sum_gf = Arc::new(SumBackwardHelper {
5803            input: output.clone(),
5804        });
5805        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
5806        backward(&loss).unwrap();
5807
5808        let grad = input.grad().unwrap().expect("grad_input populated");
5809        assert_eq!(grad.shape(), &[2, 3, 2, 2, 2]);
5810    }
5811
5812    #[test]
5813    fn test_batchnorm2d_affine_backward_still_returns_three_grads() {
5814        use ferrotorch_core::autograd::graph::backward;
5815
5816        // Regression guard for the affine=true path: grad vec stays length 3
5817        // and weight/bias grads are populated.
5818        // BatchNorm2d::new sets training=true by default.
5819        let bn = BatchNorm2d::<f64>::new(4, 1e-5, 0.1, true).unwrap();
5820        let n = 3 * 4 * 2 * 3;
5821        let data: Vec<f64> = (0..n).map(|k| ((k % 19) as f64) * 0.11 - 1.0).collect();
5822        let input = leaf(&data, &[3, 4, 2, 3], true);
5823        let output = bn.forward(&input).unwrap();
5824        assert_eq!(output.grad_fn().unwrap().inputs().len(), 3);
5825
5826        let total: f64 = output.data().unwrap().iter().sum();
5827        let sum_gf = Arc::new(SumBackwardHelper {
5828            input: output.clone(),
5829        });
5830        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
5831        backward(&loss).unwrap();
5832
5833        assert!(input.grad().unwrap().is_some());
5834        assert!(
5835            bn.weight
5836                .as_ref()
5837                .unwrap()
5838                .tensor()
5839                .grad()
5840                .unwrap()
5841                .is_some()
5842        );
5843        assert!(bn.bias.as_ref().unwrap().tensor().grad().unwrap().is_some());
5844    }
5845
5846    #[test]
5847    fn test_batchnorm1d_backward_numerical() {
5848        use ferrotorch_core::autograd::graph::backward;
5849
5850        // Numerical gradient check for BatchNorm1d.
5851        let channels = 2;
5852        let eps_val = 1e-5;
5853        let input_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
5854        let shape = [4usize, 2];
5855
5856        // Analytic gradient.
5857        let bn = BatchNorm1d::<f64>::new(channels, eps_val, 0.1, true).unwrap();
5858        let input = leaf(&input_data, &shape, true);
5859        let output = bn.forward(&input).unwrap();
5860        let out_data = output.data().unwrap().to_vec();
5861        let total: f64 = out_data.iter().sum();
5862        let sum_gf = Arc::new(SumBackwardHelper {
5863            input: output.clone(),
5864        });
5865        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
5866        backward(&loss).unwrap();
5867        let analytic_grad = input.grad().unwrap().unwrap().data_vec().unwrap();
5868
5869        // Numerical gradient.
5870        let h = 1e-5;
5871        let mut numerical_grad = vec![0.0f64; input_data.len()];
5872        for i in 0..input_data.len() {
5873            let mut data_plus = input_data.clone();
5874            data_plus[i] += h;
5875            let bn_plus = BatchNorm1d::<f64>::new(channels, eps_val, 0.1, true).unwrap();
5876            let input_plus = leaf(&data_plus, &shape, false);
5877            let out_plus = no_grad(|| bn_plus.forward(&input_plus)).unwrap();
5878            let sum_plus: f64 = out_plus.data().unwrap().iter().sum();
5879
5880            let mut data_minus = input_data.clone();
5881            data_minus[i] -= h;
5882            let bn_minus = BatchNorm1d::<f64>::new(channels, eps_val, 0.1, true).unwrap();
5883            let input_minus = leaf(&data_minus, &shape, false);
5884            let out_minus = no_grad(|| bn_minus.forward(&input_minus)).unwrap();
5885            let sum_minus: f64 = out_minus.data().unwrap().iter().sum();
5886
5887            numerical_grad[i] = (sum_plus - sum_minus) / (2.0 * h);
5888        }
5889
5890        for i in 0..input_data.len() {
5891            assert!(
5892                (analytic_grad[i] - numerical_grad[i]).abs() < 1e-3,
5893                "BatchNorm1d grad[{}]: numerical={}, analytic={}",
5894                i,
5895                numerical_grad[i],
5896                analytic_grad[i]
5897            );
5898        }
5899    }
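
    // Sketch: the finite-difference loops above approximate each partial
    // derivative with the central difference
    //     df/dx_i ~= (f(x + h * e_i) - f(x - h * e_i)) / (2 * h),
    // whose truncation error is O(h^2). The generic helper below is
    // hypothetical and not part of ferrotorch; it perturbs one coordinate at a
    // time and restores it before moving on.
    fn sketch_central_diff_grad<F>(f: F, x: &[f64], h: f64) -> Vec<f64>
    where
        F: Fn(&[f64]) -> f64,
    {
        let mut probe = x.to_vec();
        let mut grad = Vec::with_capacity(x.len());
        for i in 0..x.len() {
            let orig = probe[i];
            probe[i] = orig + h;
            let f_plus = f(&probe);
            probe[i] = orig - h;
            let f_minus = f(&probe);
            probe[i] = orig; // restore before probing the next coordinate
            grad.push((f_plus - f_minus) / (2.0 * h));
        }
        grad
    }

    #[test]
    fn sketch_central_diff_matches_sum_of_squares() {
        // f(x) = sum(x_i^2) has gradient 2 * x.
        let x = [1.0, -2.0, 0.5];
        let g = sketch_central_diff_grad(|v| v.iter().map(|a| a * a).sum(), &x, 1e-5);
        for (gi, xi) in g.iter().zip(x.iter()) {
            assert!((gi - 2.0 * xi).abs() < 1e-6);
        }
    }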
5900
5901    #[test]
5902    fn test_batchnorm1d_empty_batch() {
5903        let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
5904        let input = Tensor::from_storage(TensorStorage::cpu(vec![]), vec![0, 3], false).unwrap();
5905        let output = bn.forward(&input).unwrap();
5906        assert_eq!(output.shape(), &[0, 3]);
5907        assert_eq!(output.numel(), 0);
5908    }
5909
5910    #[test]
5911    fn test_batchnorm1d_is_send_sync() {
5912        fn assert_send_sync<T: Send + Sync>() {}
5913        assert_send_sync::<BatchNorm1d<f32>>();
5914    }
5915
5916    // -----------------------------------------------------------------------
5917    // BatchNorm3d tests — CL-434
5918    // -----------------------------------------------------------------------
5919
5920    #[test]
5921    fn test_batchnorm3d_output_shape() {
5922        let bn = BatchNorm3d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
5923        let input = Tensor::from_storage(
5924            TensorStorage::cpu(vec![1.0f32; 2 * 3 * 2 * 2 * 2]),
5925            vec![2, 3, 2, 2, 2],
5926            false,
5927        )
5928        .unwrap();
5929        let output = bn.forward(&input).unwrap();
5930        assert_eq!(output.shape(), &[2, 3, 2, 2, 2]);
5931    }
5932
5933    #[test]
5934    fn test_batchnorm3d_rejects_non_5d() {
5935        let bn = BatchNorm3d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
5936        let input =
5937            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 24]), vec![2, 3, 4], false)
5938                .unwrap();
5939        assert!(bn.forward(&input).is_err());
5940    }
5941
5942    #[test]
5943    fn test_batchnorm3d_channel_mismatch() {
5944        let bn = BatchNorm3d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
5945        let input = Tensor::from_storage(
5946            TensorStorage::cpu(vec![1.0f32; 2 * 4 * 2 * 2 * 2]),
5947            vec![2, 4, 2, 2, 2],
5948            false,
5949        )
5950        .unwrap();
5951        assert!(bn.forward(&input).is_err());
5952    }
5953
5954    #[test]
5955    fn test_batchnorm3d_zero_features_rejected() {
5956        assert!(BatchNorm3d::<f32>::new(0, 1e-5, 0.1, true).is_err());
5957    }
5958
5959    #[test]
5960    fn test_batchnorm3d_training_normalizes() {
5961        // After BatchNorm3d in training mode (weight=1, bias=0),
5962        // each channel should have approximately zero mean.
5963        let channels = 2;
5964        let bn = BatchNorm3d::<f64>::new(channels, 1e-5, 0.1, true).unwrap();
5965        let mut data = Vec::with_capacity(2 * 2 * 2 * 2 * 2);
5966        for i in 0..(2 * 2 * 2 * 2 * 2) {
5967            data.push(i as f64);
5968        }
5969        let input = leaf(&data, &[2, 2, 2, 2, 2], false);
5970        let output = bn.forward(&input).unwrap();
5971        let out_data = output.data().unwrap();
5972
5973        let spatial = 2 * 2 * 2;
5974        let batch = 2;
5975        for c in 0..channels {
5976            let mut sum = 0.0;
5977            for b in 0..batch {
5978                let base = b * channels * spatial + c * spatial;
5979                for s in 0..spatial {
5980                    sum += out_data[base + s];
5981                }
5982            }
5983            let mean = sum / (batch * spatial) as f64;
5984            assert!(mean.abs() < 1e-5, "channel {c} mean = {mean}, expected ~0");
5985        }
5986    }
5987
5988    #[test]
5989    fn test_batchnorm3d_running_stats_updated() {
5990        let bn = BatchNorm3d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
5991        let data: Vec<f64> = (0..32).map(|i| i as f64).collect();
5992        let input = leaf(&data, &[2, 2, 2, 2, 2], false);
5993        let _ = bn.forward(&input).unwrap();
5994
5995        assert_eq!(bn.num_batches_tracked(), 1);
5996        let rm = bn.running_mean();
5997        assert!(
5998            rm[0] != 0.0 || rm[1] != 0.0,
5999            "running mean should be updated"
6000        );
6001    }
6002
6003    #[test]
6004    fn test_batchnorm3d_eval_uses_running_stats() {
6005        let mut bn = BatchNorm3d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
6006        // Train on one batch to populate running stats.
6007        let data: Vec<f64> = (0..32).map(|i| i as f64).collect();
6008        let input = leaf(&data, &[2, 2, 2, 2, 2], false);
6009        let _ = bn.forward(&input).unwrap();
6010
6011        bn.eval();
6012        // Eval forward normalizes with running stats, not batch stats; only the shape is checked.
6013        let output = bn.forward(&input).unwrap();
6014        assert_eq!(output.shape(), &[2, 2, 2, 2, 2]);
6015    }
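
    // Sketch of what "eval forward uses running stats" means per element,
    // following PyTorch's documented eval-mode BatchNorm:
    //     y = (x - running_mean[c]) / sqrt(running_var[c] + eps) * weight[c] + bias[c].
    // Because the statistics are frozen, an element's output no longer
    // depends on the rest of the batch. The helper is hypothetical, not
    // ferrotorch's kernel.
    fn sketch_batch_norm_eval(
        x: f64,
        running_mean: f64,
        running_var: f64,
        eps: f64,
        weight: f64,
        bias: f64,
    ) -> f64 {
        (x - running_mean) / (running_var + eps).sqrt() * weight + bias
    }

    #[test]
    fn sketch_batch_norm_eval_is_affine_in_x() {
        // With running_mean = 2, running_var = 4 and eps = 0: (6 - 2) / 2 = 2.
        let y = sketch_batch_norm_eval(6.0, 2.0, 4.0, 0.0, 1.0, 0.0);
        assert!((y - 2.0).abs() < 1e-12);
        // weight scales and bias shifts the normalized value: 2 * 3 + 0.5.
        let y2 = sketch_batch_norm_eval(6.0, 2.0, 4.0, 0.0, 3.0, 0.5);
        assert!((y2 - 6.5).abs() < 1e-12);
    }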
6016
6017    #[test]
6018    fn test_batchnorm3d_parameters() {
6019        let bn = BatchNorm3d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
6020        let params = bn.parameters();
6021        assert_eq!(params.len(), 2);
6022        assert_eq!(params[0].shape(), &[4]); // weight
6023        assert_eq!(params[1].shape(), &[4]); // bias
6024    }
6025
6026    #[test]
6027    fn test_batchnorm3d_no_affine_no_params() {
6028        let bn = BatchNorm3d::<f32>::new(4, 1e-5, 0.1, false).unwrap();
6029        assert!(bn.parameters().is_empty());
6030    }
6031
6032    #[test]
6033    fn test_batchnorm3d_backward_grad_shapes() {
6034        use ferrotorch_core::autograd::graph::backward;
6035
6036        let bn = BatchNorm3d::<f64>::new(2, 1e-5, 0.1, true).unwrap();
6037        let data: Vec<f64> = (0..32).map(|i| i as f64).collect();
6038        let input = leaf(&data, &[2, 2, 2, 2, 2], true);
6039        let output = bn.forward(&input).unwrap();
6040
6041        let out_data = output.data().unwrap().to_vec();
6042        let total: f64 = out_data.iter().sum();
6043        let sum_gf = Arc::new(SumBackwardHelper {
6044            input: output.clone(),
6045        });
6046        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
6047        backward(&loss).unwrap();
6048
6049        let grad = input.grad().unwrap().unwrap();
6050        assert_eq!(grad.shape(), &[2, 2, 2, 2, 2]);
6051    }
6052
6053    #[test]
6054    fn test_batchnorm3d_backward_numerical() {
6055        use ferrotorch_core::autograd::graph::backward;
6056
6057        let channels = 2;
6058        let eps_val = 1e-5;
6059        let data: Vec<f64> = (0..32).map(|i| i as f64 * 0.1).collect();
6060        let shape = [2usize, 2, 2, 2, 2];
6061
6062        let bn = BatchNorm3d::<f64>::new(channels, eps_val, 0.1, true).unwrap();
6063        let input = leaf(&data, &shape, true);
6064        let output = bn.forward(&input).unwrap();
6065        let out_data = output.data().unwrap().to_vec();
6066        let total: f64 = out_data.iter().sum();
6067        let sum_gf = Arc::new(SumBackwardHelper {
6068            input: output.clone(),
6069        });
6070        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
6071        backward(&loss).unwrap();
6072        let analytic_grad = input.grad().unwrap().unwrap().data_vec().unwrap();
6073
6074        let h = 1e-5;
6075        for i in 0..data.len() {
6076            let mut data_plus = data.clone();
6077            data_plus[i] += h;
6078            let bn_plus = BatchNorm3d::<f64>::new(channels, eps_val, 0.1, true).unwrap();
6079            let input_plus = leaf(&data_plus, &shape, false);
6080            let out_plus = no_grad(|| bn_plus.forward(&input_plus)).unwrap();
6081            let sum_plus: f64 = out_plus.data().unwrap().iter().sum();
6082
6083            let mut data_minus = data.clone();
6084            data_minus[i] -= h;
6085            let bn_minus = BatchNorm3d::<f64>::new(channels, eps_val, 0.1, true).unwrap();
6086            let input_minus = leaf(&data_minus, &shape, false);
6087            let out_minus = no_grad(|| bn_minus.forward(&input_minus)).unwrap();
6088            let sum_minus: f64 = out_minus.data().unwrap().iter().sum();
6089
6090            let numerical = (sum_plus - sum_minus) / (2.0 * h);
6091            assert!(
6092                (analytic_grad[i] - numerical).abs() < 1e-3,
6093                "BatchNorm3d grad[{}]: numerical={}, analytic={}",
6094                i,
6095                numerical,
6096                analytic_grad[i]
6097            );
6098        }
6099    }
6100
6101    #[test]
6102    fn test_batchnorm3d_is_send_sync() {
6103        fn assert_send_sync<T: Send + Sync>() {}
6104        assert_send_sync::<BatchNorm3d<f32>>();
6105    }
6106
6107    // -----------------------------------------------------------------------
6108    // LocalResponseNorm tests — CL-435
6109    // -----------------------------------------------------------------------
6110
6111    #[test]
6112    fn test_lrn_output_shape() {
6113        let lrn = LocalResponseNorm::new(5, 1e-4, 0.75, 1.0).unwrap();
6114        let input = Tensor::<f32>::from_storage(
6115            TensorStorage::cpu(vec![1.0f32; 2 * 4 * 3 * 3]),
6116            vec![2, 4, 3, 3],
6117            false,
6118        )
6119        .unwrap();
6120        let output = Module::<f32>::forward(&lrn, &input).unwrap();
6121        assert_eq!(output.shape(), &[2, 4, 3, 3]);
6122    }
6123
6124    #[test]
6125    fn test_lrn_3d_input() {
6126        let lrn = LocalResponseNorm::new(3, 1e-4, 0.75, 1.0).unwrap();
6127        let input = Tensor::<f32>::from_storage(
6128            TensorStorage::cpu(vec![1.0f32; 2 * 4 * 8]),
6129            vec![2, 4, 8],
6130            false,
6131        )
6132        .unwrap();
6133        let output = Module::<f32>::forward(&lrn, &input).unwrap();
6134        assert_eq!(output.shape(), &[2, 4, 8]);
6135    }
6136
6137    #[test]
6138    fn test_lrn_rejects_2d() {
6139        let lrn = LocalResponseNorm::new(3, 1e-4, 0.75, 1.0).unwrap();
6140        let input =
6141            Tensor::<f32>::from_storage(TensorStorage::cpu(vec![1.0f32; 8]), vec![2, 4], false)
6142                .unwrap();
6143        assert!(Module::<f32>::forward(&lrn, &input).is_err());
6144    }
6145
6146    #[test]
6147    fn test_lrn_zero_size_rejected() {
6148        assert!(LocalResponseNorm::new(0, 1e-4, 0.75, 1.0).is_err());
6149    }
6150
6151    #[test]
6152    fn test_lrn_default_params() {
6153        let lrn = LocalResponseNorm::default_params(5).unwrap();
6154        assert_eq!(lrn.size, 5);
6155        assert!((lrn.alpha - 1e-4).abs() < 1e-10);
6156        assert!((lrn.beta - 0.75).abs() < 1e-10);
6157        assert!((lrn.k - 1.0).abs() < 1e-10);
6158    }
6159
6160    #[test]
6161    fn test_lrn_no_parameters() {
6162        let lrn = LocalResponseNorm::new(5, 1e-4, 0.75, 1.0).unwrap();
6163        assert!(Module::<f32>::parameters(&lrn).is_empty());
6164    }
6165
6166    #[test]
6167    fn test_lrn_divides_by_norm() {
6168        // With a large alpha (10.0) and k = 1.0, every output should be
6169        // strongly attenuated relative to the all-ones input.
6170        let lrn = LocalResponseNorm::new(3, 10.0, 1.0, 1.0).unwrap();
6171        let data: Vec<f32> = vec![1.0; 3 * 2];
6172        let input =
6173            Tensor::<f32>::from_storage(TensorStorage::cpu(data), vec![1, 3, 2], false).unwrap();
6174        let output = Module::<f32>::forward(&lrn, &input).unwrap();
6175        let out_data = output.data().unwrap();
6176
6177        // All outputs lie in (0, 1): with k = 1 and a positive window sum, the denominator exceeds 1.
6178        for &v in out_data.iter() {
6179            assert!(
6180                v < 1.0 && v > 0.0,
6181                "LRN output {v} should be attenuated (0 < v < 1)"
6182            );
6183        }
6184    }
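
    // Sketch of the cross-channel LRN formula documented for PyTorch's
    // `nn.LocalResponseNorm`, at a single spatial position:
    //     b_c = a_c / (k + alpha / size * sum_{c' in window(c)} a_{c'}^2)^beta,
    // where window(c) covers channels c - size/2 ..= c + (size - 1)/2, clipped
    // to the valid range (symmetric for odd `size`). With k >= 1 and a
    // non-negative sum, the denominator is at least 1, so positive inputs can
    // only shrink. The helper is hypothetical; ferrotorch's kernel may differ
    // in layout and edge handling.
    fn sketch_lrn_channels(a: &[f64], size: usize, alpha: f64, beta: f64, k: f64) -> Vec<f64> {
        assert!(size > 0, "size must be positive");
        let c = a.len();
        (0..c)
            .map(|ch| {
                let lo = ch.saturating_sub(size / 2);
                let hi = (ch + (size - 1) / 2).min(c - 1);
                let sq_sum: f64 = a[lo..=hi].iter().map(|v| v * v).sum();
                a[ch] / (k + alpha / size as f64 * sq_sum).powf(beta)
            })
            .collect()
    }

    #[test]
    fn sketch_lrn_matches_hand_computation() {
        // size = 3, alpha = 10, beta = 1, k = 1 on three channels of ones:
        // edge channels see 2 squares, the middle channel sees 3.
        let out = sketch_lrn_channels(&[1.0, 1.0, 1.0], 3, 10.0, 1.0, 1.0);
        assert!((out[0] - 1.0 / (1.0 + 10.0 / 3.0 * 2.0)).abs() < 1e-12);
        assert!((out[1] - 1.0 / 11.0).abs() < 1e-12);
        assert!(out.iter().all(|&v| v > 0.0 && v < 1.0));
    }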
6185
6186    #[test]
6187    fn test_lrn_backward_numerical() {
6188        let lrn = LocalResponseNorm::new(3, 1e-4, 0.75, 1.0).unwrap();
6189        let input_data: Vec<f64> = vec![
6190            1.0, -0.5, 2.0, 0.3, 0.7, -1.2, 0.4, 1.5, -0.3, 0.8, 1.1, -0.7,
6191        ];
6192        let shape = vec![1usize, 3, 4];
6193
6194        let input = leaf(&input_data, &shape, true);
6195        let output = Module::<f64>::forward(&lrn, &input).unwrap();
6196        let out_data = output.data().unwrap().to_vec();
6197        let total: f64 = out_data.iter().sum();
6198
6199        let sum_gf = Arc::new(SumBackwardHelper {
6200            input: output.clone(),
6201        });
6202        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
6203        loss.backward().unwrap();
6204
6205        let analytic_grad = input.grad().unwrap().unwrap();
6206        let analytic = analytic_grad.data().unwrap().to_vec();
6207
6208        let h = 1e-6;
6209        for i in 0..input_data.len() {
6210            let mut data_plus = input_data.clone();
6211            data_plus[i] += h;
6212            let inp_plus = leaf(&data_plus, &shape, false);
6213            let out_plus = no_grad(|| Module::<f64>::forward(&lrn, &inp_plus)).unwrap();
6214            let sum_plus: f64 = out_plus.data().unwrap().iter().sum();
6215
6216            let mut data_minus = input_data.clone();
6217            data_minus[i] -= h;
6218            let inp_minus = leaf(&data_minus, &shape, false);
6219            let out_minus = no_grad(|| Module::<f64>::forward(&lrn, &inp_minus)).unwrap();
6220            let sum_minus: f64 = out_minus.data().unwrap().iter().sum();
6221
6222            let numerical = (sum_plus - sum_minus) / (2.0 * h);
6223            assert!(
6224                (numerical - analytic[i]).abs() < 1e-4,
6225                "LRN grad[{i}]: numerical={numerical}, analytic={}",
6226                analytic[i]
6227            );
6228        }
6229    }
6230
6231    // -----------------------------------------------------------------------
6232    // InstanceNorm tests — CL-315
6233    // -----------------------------------------------------------------------
6234
6235    #[test]
6236    fn test_instancenorm1d_output_shape() {
6237        let norm = InstanceNorm1d::<f32>::new(3, 1e-5, true).unwrap();
6238        // Input [B=2, C=3, L=8]
6239        let input =
6240            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 48]), vec![2, 3, 8], false)
6241                .unwrap();
6242        let out = norm.forward(&input).unwrap();
6243        assert_eq!(out.shape(), &[2, 3, 8]);
6244    }
6245
6246    #[test]
6247    fn test_instancenorm1d_rejects_wrong_ndim() {
6248        let norm = InstanceNorm1d::<f32>::new(3, 1e-5, true).unwrap();
6249        // 4D input should fail.
6250        let input = Tensor::from_storage(
6251            TensorStorage::cpu(vec![1.0f32; 48]),
6252            vec![2, 3, 4, 2],
6253            false,
6254        )
6255        .unwrap();
6256        assert!(norm.forward(&input).is_err());
6257    }
6258
6259    #[test]
6260    fn test_instancenorm2d_normalizes_per_instance_channel() {
6261        // Each (b, c) spatial plane should have ~zero mean, ~unit var after norm.
6262        let norm = InstanceNorm2d::<f32>::new(2, 1e-5, true).unwrap();
6263        // [B=1, C=2, H=2, W=2]
6264        let data: Vec<f32> = vec![
6265            1.0, 2.0, 3.0, 4.0, // channel 0
6266            5.0, 6.0, 7.0, 8.0, // channel 1
6267        ];
6268        let input =
6269            Tensor::from_storage(TensorStorage::cpu(data), vec![1, 2, 2, 2], false).unwrap();
6270        let out = norm.forward(&input).unwrap();
6271        let d = out.data().unwrap();
6272
6273        // Check each channel independently.
6274        for c in 0..2 {
6275            let start = c * 4;
6276            let end = start + 4;
6277            let slice = &d[start..end];
6278            let mean: f32 = slice.iter().sum::<f32>() / 4.0;
6279            let var: f32 = slice.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / 4.0;
6280            assert!(mean.abs() < 1e-5, "channel {c} mean = {mean}, expected ~0");
6281            assert!(
6282                (var - 1.0).abs() < 0.1,
6283                "channel {c} var = {var}, expected ~1"
6284            );
6285        }
6286    }
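
    // Sketch contrasting the reduction axes: InstanceNorm2d computes one mean
    // and one biased variance per (b, c) plane of H * W values, whereas
    // BatchNorm2d pools over (N, H, W) for each channel. In row-major
    // [N, C, H, W] every plane is a contiguous chunk of H * W values, so
    // `chunks_mut` visits them in order. Hypothetical helper, not ferrotorch's.
    fn sketch_instance_norm_planes(x: &mut [f64], plane: usize, eps: f64) {
        assert!(plane > 0 && x.len() % plane == 0, "len must be a multiple of H * W");
        for p in x.chunks_mut(plane) {
            let mean = p.iter().sum::<f64>() / plane as f64;
            let var = p.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / plane as f64;
            let inv_std = 1.0 / (var + eps).sqrt();
            for v in p.iter_mut() {
                *v = (*v - mean) * inv_std;
            }
        }
    }

    #[test]
    fn sketch_instance_norm_planes_independent() {
        // Two planes that differ only by an offset normalize to identical values.
        let mut x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        sketch_instance_norm_planes(&mut x, 4, 1e-5);
        for (a, b) in x[..4].iter().zip(&x[4..]) {
            assert!((a - b).abs() < 1e-12);
        }
    }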
6287
6288    #[test]
6289    fn test_instancenorm2d_rejects_wrong_ndim() {
6290        let norm = InstanceNorm2d::<f32>::new(3, 1e-5, true).unwrap();
6291        // 3D input should fail.
6292        let input =
6293            Tensor::from_storage(TensorStorage::cpu(vec![1.0f32; 24]), vec![2, 3, 4], false)
6294                .unwrap();
6295        assert!(norm.forward(&input).is_err());
6296    }
6297
6298    #[test]
6299    fn test_instancenorm3d_output_shape() {
6300        let norm = InstanceNorm3d::<f32>::new(2, 1e-5, false).unwrap();
6301        // [B=1, C=2, D=2, H=2, W=2]
6302        let input = Tensor::from_storage(
6303            TensorStorage::cpu(vec![1.0f32; 16]),
6304            vec![1, 2, 2, 2, 2],
6305            false,
6306        )
6307        .unwrap();
6308        let out = norm.forward(&input).unwrap();
6309        assert_eq!(out.shape(), &[1, 2, 2, 2, 2]);
6310    }
6311
6312    #[test]
6313    fn test_instancenorm2d_no_affine_no_params() {
6314        let norm = InstanceNorm2d::<f32>::new(4, 1e-5, false).unwrap();
6315        assert!(Module::<f32>::parameters(&norm).is_empty());
6316    }
6317
6318    #[test]
6319    fn test_instancenorm2d_has_affine_params() {
6320        let norm = InstanceNorm2d::<f32>::new(4, 1e-5, true).unwrap();
6321        let params = Module::<f32>::parameters(&norm);
6322        assert_eq!(params.len(), 2);
6323        assert_eq!(params[0].shape(), &[4]); // weight
6324        assert_eq!(params[1].shape(), &[4]); // bias
6325    }
6326
6327    #[test]
6328    fn test_instancenorm2d_backward_gradient_check() {
6329        let h = 1e-7;
6330        let num_features = 2;
6331        // Input [1, 2, 2, 2]
6332        let input_data: Vec<f64> = vec![1.0, -0.5, 2.0, 0.3, 0.7, -1.2, 0.4, 1.5];
6333        let shape = vec![1usize, 2, 2, 2];
6334
6335        let norm = InstanceNorm2d::<f64>::new(num_features, 1e-5, true).unwrap();
6336
6337        // Forward + backward.
6338        let input = leaf(&input_data, &shape, true);
6339        let output = norm.forward(&input).unwrap();
6340        let out_data = output.data().unwrap().to_vec();
6341        let total: f64 = out_data.iter().sum();
6342
6343        let sum_gf = Arc::new(SumBackwardHelper {
6344            input: output.clone(),
6345        });
6346        let loss = Tensor::from_operation(TensorStorage::cpu(vec![total]), vec![], sum_gf).unwrap();
6347        loss.backward().unwrap();
6348
6349        let analytic_grad = input.grad().unwrap().unwrap();
6350        let analytic = analytic_grad.data().unwrap().to_vec();
6351
6352        // Numerical gradient.
6353        for i in 0..input_data.len() {
6354            let mut data_plus = input_data.clone();
6355            data_plus[i] += h;
6356            let inp_plus = leaf(&data_plus, &shape, false);
6357            let out_plus = no_grad(|| norm.forward(&inp_plus)).unwrap();
6358            let sum_plus: f64 = out_plus.data().unwrap().iter().sum();
6359
6360            let mut data_minus = input_data.clone();
6361            data_minus[i] -= h;
6362            let inp_minus = leaf(&data_minus, &shape, false);
6363            let out_minus = no_grad(|| norm.forward(&inp_minus)).unwrap();
6364            let sum_minus: f64 = out_minus.data().unwrap().iter().sum();
6365
6366            let numerical = (sum_plus - sum_minus) / (2.0 * h);
6367            assert!(
6368                (numerical - analytic[i]).abs() < 1e-4,
6369                "InstanceNorm2d grad[{i}]: numerical={numerical}, analytic={}",
6370                analytic[i]
6371            );
6372        }
6373    }
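
    // Sketch (std-only, not part of ferrotorch): a central-difference gradient
    // check on a hypothetical per-plane normalizer `sketch_inorm_ref` (mean and
    // biased variance over one (batch, channel) plane, identity affine). It
    // shows why a plain `sum(output)` loss has a ~0 gradient for this layer,
    // and how a weighted loss gives non-zero gradients worth comparing. All
    // helper names here are illustrative assumptions.

    /// Hypothetical reference: normalize one plane as InstanceNorm would.
    fn sketch_inorm_ref(x: &[f64], eps: f64) -> Vec<f64> {
        let n = x.len() as f64;
        let mean = x.iter().sum::<f64>() / n;
        let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
        let denom = (var + eps).sqrt();
        x.iter().map(|v| (v - mean) / denom).collect()
    }

    /// Central difference: df/dx_i ~= (f(x + h e_i) - f(x - h e_i)) / (2h).
    fn sketch_central_diff(f: &dyn Fn(&[f64]) -> f64, x: &[f64], h: f64) -> Vec<f64> {
        (0..x.len())
            .map(|i| {
                let mut plus = x.to_vec();
                plus[i] += h;
                let mut minus = x.to_vec();
                minus[i] -= h;
                (f(&plus) - f(&minus)) / (2.0 * h)
            })
            .collect()
    }

    #[test]
    fn sketch_sum_loss_gradient_vanishes_for_identity_affine() {
        let x: [f64; 4] = [1.0, -0.5, 2.0, 0.3];
        // Plain sum: normalized values sum to 0, so the loss is constant.
        let sum_loss = |v: &[f64]| sketch_inorm_ref(v, 1e-5).iter().sum::<f64>();
        let g = sketch_central_diff(&sum_loss, &x, 1e-5);
        assert!(g.iter().all(|gi| gi.abs() < 1e-6));
        // Weighted sum: the gradient is generally non-zero.
        let w: [f64; 4] = [1.0, 2.0, 3.0, 4.0];
        let weighted = |v: &[f64]| {
            sketch_inorm_ref(v, 1e-5)
                .iter()
                .zip(w.iter())
                .map(|(a, b)| a * b)
                .sum::<f64>()
        };
        let gw = sketch_central_diff(&weighted, &x, 1e-5);
        assert!(gw.iter().any(|gi| gi.abs() > 1e-3));
    }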

    #[test]
    fn test_instancenorm_zero_features_rejected() {
        assert!(InstanceNorm1d::<f32>::new(0, 1e-5, true).is_err());
        assert!(InstanceNorm2d::<f32>::new(0, 1e-5, true).is_err());
        assert!(InstanceNorm3d::<f32>::new(0, 1e-5, true).is_err());
    }

    // -----------------------------------------------------------------------
    // BatchNorm running-stat setters: Phase 2 of the value-parity pipeline (#984)
    //
    // These tests cover the surface added in #984:
    //  • set_running_mean / set_running_var / set_num_batches_tracked on
    //    BatchNorm1d / BatchNorm2d / BatchNorm3d
    //  • the as_any() downcast hook on Module that lets a state-dict
    //    loader walking named_modules() route buffer keys to the right BN
    //
    // Coverage map:
    //  - *_round_trip: setter + getter equality across all three BN types
    //  - *_rejects_*: each rejection branch (length / non-finite / negative)
    //  - *_set_running_stats_flow_through_eval_forward: non-default running
    //      stats actually shape the eval-mode forward output (the behavior
    //      Phase 1A's workaround could not exercise)
    //  - *_as_any_downcasts_to_concrete_type: the hook returns a Some that
    //      downcasts back to the concrete BN type
    // -----------------------------------------------------------------------

    /// Round-trip: set_running_mean → running_mean returns the same values.
    #[test]
    fn bn2d_set_running_mean_round_trip() {
        let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
        let v: [f32; 4] = [0.5, -1.25, 2.0, 0.0];
        bn.set_running_mean(&v).unwrap();
        let got = bn.running_mean();
        assert_eq!(got.len(), 4);
        for (i, (g, e)) in got.iter().zip(v.iter()).enumerate() {
            assert!(
                (g - *e as f64).abs() < 1e-7,
                "channel {i}: got={g}, expected={e}"
            );
        }
    }

    #[test]
    fn bn2d_set_running_var_round_trip() {
        let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
        let v: [f32; 4] = [1.5, 0.25, 2.0, 0.0];
        bn.set_running_var(&v).unwrap();
        let got = bn.running_var();
        for (i, (g, e)) in got.iter().zip(v.iter()).enumerate() {
            assert!(
                (g - *e as f64).abs() < 1e-7,
                "channel {i}: got={g}, expected={e}"
            );
        }
    }

    #[test]
    fn bn2d_set_num_batches_tracked_round_trip() {
        let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
        bn.set_num_batches_tracked(42).unwrap();
        assert_eq!(bn.num_batches_tracked(), 42);
        // Overwriting replaces the previous value.
        bn.set_num_batches_tracked(0).unwrap();
        assert_eq!(bn.num_batches_tracked(), 0);
    }

    #[test]
    fn bn1d_set_running_mean_round_trip() {
        let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let v: [f32; 3] = [-0.1, 0.2, 0.3];
        bn.set_running_mean(&v).unwrap();
        let got = bn.running_mean();
        for (i, (g, e)) in got.iter().zip(v.iter()).enumerate() {
            assert!(
                (g - *e as f64).abs() < 1e-7,
                "channel {i}: got={g}, expected={e}"
            );
        }
    }

    #[test]
    fn bn1d_set_running_var_round_trip() {
        let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let v: [f32; 3] = [0.5, 1.0, 4.0];
        bn.set_running_var(&v).unwrap();
        let got = bn.running_var();
        for (i, (g, e)) in got.iter().zip(v.iter()).enumerate() {
            assert!(
                (g - *e as f64).abs() < 1e-7,
                "channel {i}: got={g}, expected={e}"
            );
        }
    }

    #[test]
    fn bn1d_set_num_batches_tracked_round_trip() {
        let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        bn.set_num_batches_tracked(7).unwrap();
        assert_eq!(bn.num_batches_tracked(), 7);
    }

    #[test]
    fn bn3d_set_running_mean_round_trip() {
        let bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        let v: [f32; 2] = [-2.0, 3.0];
        bn.set_running_mean(&v).unwrap();
        let got = bn.running_mean();
        for (i, (g, e)) in got.iter().zip(v.iter()).enumerate() {
            assert!(
                (g - *e as f64).abs() < 1e-7,
                "channel {i}: got={g}, expected={e}"
            );
        }
    }

    #[test]
    fn bn3d_set_running_var_round_trip() {
        let bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        let v: [f32; 2] = [1.0, 0.5];
        bn.set_running_var(&v).unwrap();
        let got = bn.running_var();
        for (i, (g, e)) in got.iter().zip(v.iter()).enumerate() {
            assert!(
                (g - *e as f64).abs() < 1e-7,
                "channel {i}: got={g}, expected={e}"
            );
        }
    }

    #[test]
    fn bn3d_set_num_batches_tracked_round_trip() {
        let bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        bn.set_num_batches_tracked(11).unwrap();
        assert_eq!(bn.num_batches_tracked(), 11);
    }

    // ── Validation rejections ────────────────────────────────────────────

    #[test]
    fn bn2d_set_running_mean_rejects_wrong_length() {
        let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
        let too_short: [f32; 3] = [0.0, 0.0, 0.0];
        let err = bn
            .set_running_mean(&too_short)
            .expect_err("wrong length should error");
        match err {
            FerrotorchError::ShapeMismatch { message } => {
                assert!(message.contains("BatchNorm2d::set_running_mean"));
                assert!(message.contains("num_features=4"));
            }
            other => panic!("expected ShapeMismatch, got {other:?}"),
        }
        let too_long: [f32; 5] = [0.0; 5];
        bn.set_running_mean(&too_long)
            .expect_err("wrong length should error");
    }

    #[test]
    fn bn2d_set_running_mean_rejects_non_finite() {
        let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        let nan_val: [f32; 2] = [0.0, f32::NAN];
        match bn.set_running_mean(&nan_val).expect_err("nan should error") {
            FerrotorchError::InvalidArgument { message } => {
                assert!(message.contains("non-finite"));
                assert!(message.contains("index 1"));
            }
            other => panic!("expected InvalidArgument, got {other:?}"),
        }
        let inf_val: [f32; 2] = [f32::INFINITY, 0.0];
        bn.set_running_mean(&inf_val).expect_err("inf should error");
    }

    #[test]
    fn bn2d_set_running_var_rejects_negative() {
        let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        let v: [f32; 2] = [1.0, -0.5];
        match bn
            .set_running_var(&v)
            .expect_err("negative variance should error")
        {
            FerrotorchError::InvalidArgument { message } => {
                assert!(message.contains("negative"));
                assert!(message.contains("index 1"));
            }
            other => panic!("expected InvalidArgument, got {other:?}"),
        }
    }

    #[test]
    fn bn2d_set_running_var_rejects_non_finite() {
        let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        let v: [f32; 2] = [1.0, f32::NAN];
        match bn
            .set_running_var(&v)
            .expect_err("nan variance should error")
        {
            FerrotorchError::InvalidArgument { message } => {
                assert!(message.contains("non-finite"));
            }
            other => panic!("expected InvalidArgument, got {other:?}"),
        }
    }

    #[test]
    fn bn1d_set_running_var_rejects_negative_and_wrong_length() {
        let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let bad_neg: [f32; 3] = [1.0, -1.0, 1.0];
        bn.set_running_var(&bad_neg)
            .expect_err("negative should error");
        let bad_len: [f32; 2] = [1.0, 1.0];
        bn.set_running_var(&bad_len)
            .expect_err("wrong length should error");
    }

    #[test]
    fn bn3d_set_running_mean_rejects_non_finite_and_wrong_length() {
        let bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        let bad_nan: [f32; 2] = [f32::NAN, 0.0];
        bn.set_running_mean(&bad_nan).expect_err("nan should error");
        let bad_len: [f32; 1] = [0.0];
        bn.set_running_mean(&bad_len)
            .expect_err("wrong length should error");
    }
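
    // Sketch (hypothetical, not ferrotorch's implementation): a std-only
    // validator with the three rejection branches probed above: length against
    // `num_features`, non-finite entries, and (for variance only) negative
    // entries. Zero is accepted, matching the round-trip test that stores a
    // 0.0 variance. The message fragments mirror what the tests assert on;
    // the exact wording and the `String` error type are assumptions.
    fn sketch_validate_running_stat(
        setter: &str,
        num_features: usize,
        values: &[f32],
        reject_negative: bool,
    ) -> Result<(), String> {
        let got = values.len();
        if got != num_features {
            return Err(format!(
                "{setter}: expected {num_features} values (num_features={num_features}), got {got}"
            ));
        }
        for (i, &v) in values.iter().enumerate() {
            // Check finiteness first: NaN fails every ordered comparison, so a
            // `v < 0.0` test alone would let it through.
            if !v.is_finite() {
                return Err(format!("{setter}: non-finite value {v} at index {i}"));
            }
            if reject_negative && v < 0.0 {
                return Err(format!("{setter}: negative variance {v} at index {i}"));
            }
        }
        Ok(())
    }

    #[test]
    fn sketch_validate_running_stat_branches() {
        let var_setter = "BatchNorm2d::set_running_var";
        assert!(sketch_validate_running_stat(var_setter, 2, &[1.0, 0.0], true).is_ok());
        let e = sketch_validate_running_stat(var_setter, 2, &[1.0, -0.5], true).unwrap_err();
        assert!(e.contains("negative") && e.contains("index 1"));
    }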

    // ── Flow-through eval forward ───────────────────────────────────────
    //
    // Construct a BatchNorm2d, set non-default running stats via the
    // setters, run eval-mode forward on a known input, and compare to a
    // hand-computed expected value derived from the SET running stats.
    //
    // This is the key Phase 2 evidence: the new setters actually cause
    // forward-time normalization to use the loaded values, not BN's
    // construction-time defaults (mean=0, var=1).

    #[test]
    fn bn2d_set_running_stats_flow_through_eval_forward() {
        // 1 batch, 2 channels, 1×1 spatial, so the math is easy to check by hand.
        let mut bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        bn.set_running_mean(&[3.0_f32, -2.0]).unwrap();
        bn.set_running_var(&[4.0_f32, 0.25]).unwrap();
        bn.eval();

        // Affine defaults are weight=1, bias=0, so output = (x - mean) / sqrt(var + eps).
        // Channel 0: (5.0 - 3.0) / sqrt(4 + 1e-5) = 2 / 2.0000025... ≈ 0.9999988
        // Channel 1: (1.0 - (-2.0)) / sqrt(0.25 + 1e-5) = 3 / 0.5000099... ≈ 5.99988
        let input = Tensor::from_storage(
            TensorStorage::cpu(vec![5.0_f32, 1.0]),
            vec![1, 2, 1, 1],
            false,
        )
        .unwrap();
        let out = bn.forward(&input).unwrap();
        let data = out.data_vec().unwrap();
        assert_eq!(data.len(), 2);

        let expected_0 = 2.0_f32 / (4.0_f32 + 1e-5).sqrt();
        let expected_1 = 3.0_f32 / (0.25_f32 + 1e-5).sqrt();

        assert!(
            (data[0] - expected_0).abs() < 1e-5,
            "channel 0: got {}, expected {}",
            data[0],
            expected_0
        );
        assert!(
            (data[1] - expected_1).abs() < 1e-5,
            "channel 1: got {}, expected {}",
            data[1],
            expected_1
        );

        // Sanity: the construction-default expectation (mean=0, var=1) gives
        //   out_0 = 5 / sqrt(1+eps) ≈ 5.0   (vs our 0.9999..., factor of 5 off)
        //   out_1 = 1 / sqrt(1+eps) ≈ 1.0   (vs our 5.999..., factor of 6 off)
        // — i.e. without the setters, the values would be very different.
        // This is the "set stats actually flow through forward" check.
        assert!(
            (data[0] - 5.0_f32).abs() > 1.0,
            "data[0]={} too close to default-stats output (5.0); setter is not flowing through",
            data[0]
        );
        assert!(
            (data[1] - 1.0_f32).abs() > 1.0,
            "data[1]={} too close to default-stats output (1.0); setter is not flowing through",
            data[1]
        );
    }

    #[test]
    fn bn1d_set_running_stats_flow_through_eval_forward() {
        // 1 batch, 2 channels, 2D input [N, C].
        let mut bn = BatchNorm1d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        bn.set_running_mean(&[1.0_f32, 0.0]).unwrap();
        bn.set_running_var(&[9.0_f32, 4.0]).unwrap();
        bn.eval();

        let input = Tensor::from_storage(TensorStorage::cpu(vec![4.0_f32, 6.0]), vec![1, 2], false)
            .unwrap();
        let out = bn.forward(&input).unwrap();
        let data = out.data_vec().unwrap();

        // (4 - 1) / sqrt(9 + 1e-5) ≈ 1.0
        // (6 - 0) / sqrt(4 + 1e-5) ≈ 3.0
        let expected_0 = 3.0_f32 / (9.0_f32 + 1e-5).sqrt();
        let expected_1 = 6.0_f32 / (4.0_f32 + 1e-5).sqrt();
        assert!(
            (data[0] - expected_0).abs() < 1e-5,
            "BN1d ch0: got {}, expected {}",
            data[0],
            expected_0
        );
        assert!(
            (data[1] - expected_1).abs() < 1e-5,
            "BN1d ch1: got {}, expected {}",
            data[1],
            expected_1
        );
    }

    #[test]
    fn bn3d_set_running_stats_flow_through_eval_forward() {
        // 1 batch, 2 channels, 1×1×1 spatial.
        let mut bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        bn.set_running_mean(&[2.0_f32, -1.0]).unwrap();
        bn.set_running_var(&[1.0_f32, 16.0]).unwrap();
        bn.eval();

        let input = Tensor::from_storage(
            TensorStorage::cpu(vec![3.0_f32, 7.0]),
            vec![1, 2, 1, 1, 1],
            false,
        )
        .unwrap();
        let out = bn.forward(&input).unwrap();
        let data = out.data_vec().unwrap();

        // (3 - 2) / sqrt(1 + 1e-5) ≈ 1.0
        // (7 - (-1)) / sqrt(16 + 1e-5) ≈ 2.0
        let expected_0 = 1.0_f32 / (1.0_f32 + 1e-5).sqrt();
        let expected_1 = 8.0_f32 / (16.0_f32 + 1e-5).sqrt();
        assert!(
            (data[0] - expected_0).abs() < 1e-5,
            "BN3d ch0: got {}, expected {}",
            data[0],
            expected_0
        );
        assert!(
            (data[1] - expected_1).abs() < 1e-5,
            "BN3d ch1: got {}, expected {}",
            data[1],
            expected_1
        );
    }
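
    // Sketch (std-only, hypothetical helper): the eval-mode formula the three
    // flow-through tests check by hand,
    //   y = (x - mean[c]) / sqrt(var[c] + eps) * weight[c] + bias[c],
    // applied to a contiguous [N, C, S] buffer where S is the product of the
    // spatial dims (S = 1 for BatchNorm1d on [N, C]). The contiguous
    // channel-major layout is an assumption of this sketch.
    #[allow(clippy::too_many_arguments)]
    fn sketch_bn_eval_ref(
        x: &[f32],
        n: usize,
        c: usize,
        s: usize,
        mean: &[f32],
        var: &[f32],
        weight: &[f32],
        bias: &[f32],
        eps: f32,
    ) -> Vec<f32> {
        assert_eq!(x.len(), n * c * s);
        let mut out = vec![0.0_f32; x.len()];
        for b in 0..n {
            for ch in 0..c {
                let inv_std = 1.0 / (var[ch] + eps).sqrt();
                for k in 0..s {
                    let idx = (b * c + ch) * s + k;
                    out[idx] = (x[idx] - mean[ch]) * inv_std * weight[ch] + bias[ch];
                }
            }
        }
        out
    }

    #[test]
    fn sketch_bn_eval_ref_matches_hand_values() {
        // Same numbers as the BatchNorm1d flow-through test above.
        let out = sketch_bn_eval_ref(
            &[4.0, 6.0], 1, 2, 1, &[1.0, 0.0], &[9.0, 4.0], &[1.0, 1.0], &[0.0, 0.0], 1e-5,
        );
        assert!((out[0] - 1.0).abs() < 1e-4 && (out[1] - 3.0).abs() < 1e-4);
    }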

    // ── Module::as_any downcast hook ─────────────────────────────────────

    #[test]
    fn bn2d_as_any_downcasts_to_concrete_type() {
        let bn = BatchNorm2d::<f32>::new(4, 1e-5, 0.1, true).unwrap();
        let dyn_module: &dyn Module<f32> = &bn;
        let any = dyn_module
            .as_any()
            .expect("BatchNorm2d::as_any returns Some");
        let concrete = any
            .downcast_ref::<BatchNorm2d<f32>>()
            .expect("any must downcast to BatchNorm2d<f32>");
        assert_eq!(concrete.num_features, 4);
        // Wrong-type downcast must fail.
        assert!(any.downcast_ref::<BatchNorm1d<f32>>().is_none());
        assert!(any.downcast_ref::<BatchNorm3d<f32>>().is_none());
    }

    #[test]
    fn bn1d_as_any_downcasts_to_concrete_type() {
        let bn = BatchNorm1d::<f32>::new(3, 1e-5, 0.1, true).unwrap();
        let dyn_module: &dyn Module<f32> = &bn;
        let any = dyn_module
            .as_any()
            .expect("BatchNorm1d::as_any returns Some");
        let concrete = any
            .downcast_ref::<BatchNorm1d<f32>>()
            .expect("any must downcast to BatchNorm1d<f32>");
        assert_eq!(concrete.num_features, 3);
    }

    #[test]
    fn bn3d_as_any_downcasts_to_concrete_type() {
        let bn = BatchNorm3d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        let dyn_module: &dyn Module<f32> = &bn;
        let any = dyn_module
            .as_any()
            .expect("BatchNorm3d::as_any returns Some");
        let concrete = any
            .downcast_ref::<BatchNorm3d<f32>>()
            .expect("any must downcast to BatchNorm3d<f32>");
        assert_eq!(concrete.num_features, 2);
    }

    /// Sanity: the default `as_any` on Module returns None for a non-BN
    /// module (LayerNorm here), so a generic loader walking
    /// named_modules() correctly skips non-BN modules.
    #[test]
    fn non_bn_module_as_any_returns_none() {
        let ln = LayerNorm::<f32>::new(vec![4], 1e-5, true).unwrap();
        let dyn_module: &dyn Module<f32> = &ln;
        assert!(
            dyn_module.as_any().is_none(),
            "non-BN modules must not opt into the as_any downcast hook"
        );
    }
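
    // Sketch of the opt-in downcast pattern the as_any tests exercise, using
    // only `std::any::Any`. `SketchModule`, `SketchBn`, `SketchLinear`, and
    // `sketch_bn_width` are hypothetical stand-ins, not ferrotorch items: the
    // trait's default `as_any` returns `None`, only the BN-like type overrides
    // it, so a loader can skip every module that does not opt in.
    trait SketchModule {
        fn as_any(&self) -> Option<&dyn std::any::Any> {
            None
        }
    }

    struct SketchBn {
        num_features: usize,
    }

    impl SketchModule for SketchBn {
        fn as_any(&self) -> Option<&dyn std::any::Any> {
            Some(self)
        }
    }

    struct SketchLinear;

    // Relies on the default `as_any`, i.e. does not opt in.
    impl SketchModule for SketchLinear {}

    /// Hypothetical loader step: return the BN width if `m` is a `SketchBn`.
    fn sketch_bn_width(m: &dyn SketchModule) -> Option<usize> {
        m.as_any()?
            .downcast_ref::<SketchBn>()
            .map(|bn| bn.num_features)
    }

    #[test]
    fn sketch_as_any_routes_only_opted_in_modules() {
        assert_eq!(sketch_bn_width(&SketchBn { num_features: 4 }), Some(4));
        assert_eq!(sketch_bn_width(&SketchLinear), None);
    }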

    /// Setting running stats does NOT bump `num_batches_tracked`: the buffers
    /// are independent, and `num_batches_tracked` only advances during a
    /// training-mode forward.
    #[test]
    fn bn2d_set_running_mean_does_not_touch_nbt() {
        let bn = BatchNorm2d::<f32>::new(2, 1e-5, 0.1, true).unwrap();
        assert_eq!(bn.num_batches_tracked(), 0);
        bn.set_running_mean(&[0.5_f32, 0.5]).unwrap();
        assert_eq!(bn.num_batches_tracked(), 0);
        bn.set_running_var(&[1.0_f32, 1.0]).unwrap();
        assert_eq!(bn.num_batches_tracked(), 0);
    }

    /// #1357: `GroupNorm::forward` on a CUDA-resident input must route through
    /// `GpuBackend::group_norm_f32` and produce values matching the CPU path
    /// within f32 tolerance.
    ///
    /// Gated `#[ignore]` because it needs real CUDA hardware (the build host
    /// has none); it documents the expected GPU↔CPU parity for a future
    /// CUDA-host run. Tracking #1356/#1357.
    #[test]
    #[cfg(feature = "cuda")]
    #[ignore = "needs CUDA hardware; tracking #1356/#1357"]
    fn group_norm_forward_gpu_matches_cpu() {
        use crate::module::Module as _;
        use ferrotorch_core::Device;
        use ferrotorch_gpu::init_cuda_backend;
        init_cuda_backend().expect("CUDA init failed");

        // [B=2, C=8, H=3, W=4] with G=4 groups.
        let b = 2;
        let c = 8;
        let h = 3;
        let w = 4;
        let n = b * c * h * w;
        let data: Vec<f32> = (0..n).map(|k| ((k % 17) as f32) * 0.13 - 1.1).collect();

        // Non-trivial affine so the GPU path exercises weight/bias, not just
        // the identity gamma=1 / beta=0 defaults.
        let gamma: Vec<f32> = (0..c).map(|k| 1.0 + 0.05 * (k as f32)).collect();
        let beta: Vec<f32> = (0..c).map(|k| -0.1 + 0.02 * (k as f32)).collect();
        let mut gn = GroupNorm::<f32>::new(4, c, 1e-5, true).unwrap();
        gn.weight
            .set_data(Tensor::from_storage(TensorStorage::cpu(gamma), vec![c], false).unwrap());
        gn.bias
            .set_data(Tensor::from_storage(TensorStorage::cpu(beta), vec![c], false).unwrap());

        let x_cpu = Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![b, c, h, w], false)
            .unwrap();
        let y_cpu = gn.forward(&x_cpu).unwrap();
        let cpu_vals = y_cpu.data().unwrap().to_vec();

        // Move the whole module (weight + bias) to CUDA, then run on a
        // CUDA-resident input so `GroupNorm::forward` takes the GPU fast path.
        gn.to_device(Device::Cuda(0)).unwrap();
        let x_gpu = x_cpu.to(Device::Cuda(0)).unwrap();
        let y_gpu = gn.forward(&x_gpu).unwrap();
        assert!(y_gpu.is_cuda(), "GroupNorm GPU output must stay on CUDA");
        let gpu_vals = y_gpu.data_vec().unwrap();

        assert_eq!(gpu_vals.len(), cpu_vals.len());
        let mut max_abs = 0.0f32;
        // `cv` (not `c`) so the loop does not shadow the channel count.
        for (g, cv) in gpu_vals.iter().zip(cpu_vals.iter()) {
            max_abs = max_abs.max((g - cv).abs());
        }
        assert!(max_abs < 1e-4, "GroupNorm GPU vs CPU max|Δ| = {max_abs}");
    }
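
    // Sketch (std-only, hypothetical): a CPU reference for GroupNorm on a
    // contiguous [B, C, S] buffer (S = H * W), the computation the GPU parity
    // test compares against. Each (batch, group) slice of C / G channels times
    // S positions is normalized with its own mean and biased variance, then
    // the per-channel affine is applied. The layout and helper name are
    // assumptions of this sketch, not ferrotorch's kernel.
    #[allow(clippy::too_many_arguments)]
    fn sketch_group_norm_ref(
        x: &[f32],
        b: usize,
        c: usize,
        s: usize,
        groups: usize,
        gamma: &[f32],
        beta: &[f32],
        eps: f32,
    ) -> Vec<f32> {
        assert!(groups > 0 && c % groups == 0, "C must be divisible by G");
        assert_eq!(x.len(), b * c * s);
        let cpg = c / groups; // channels per group
        let group_len = cpg * s; // contiguous elements per (batch, group)
        let mut out = vec![0.0_f32; x.len()];
        for bi in 0..b {
            for g in 0..groups {
                let start = (bi * c + g * cpg) * s;
                let slice = &x[start..start + group_len];
                let n = group_len as f32;
                let mean = slice.iter().sum::<f32>() / n;
                let var = slice.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
                let inv_std = 1.0 / (var + eps).sqrt();
                for (k, v) in slice.iter().enumerate() {
                    let ch = g * cpg + k / s; // channel owning this element
                    out[start + k] = (v - mean) * inv_std * gamma[ch] + beta[ch];
                }
            }
        }
        out
    }

    #[test]
    fn sketch_group_norm_ref_groups_are_standardized() {
        let x: Vec<f32> = (0..12).map(|k| ((k % 5) as f32) * 0.3 - 0.4).collect();
        let y = sketch_group_norm_ref(&x, 1, 4, 3, 2, &[1.0; 4], &[0.0; 4], 1e-5);
        for grp in y.chunks(6) {
            assert!((grp.iter().sum::<f32>() / 6.0).abs() < 1e-5);
        }
    }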

    /// #1449: `BatchNorm2d::forward` in **eval** mode on a CUDA-resident input
    /// must route through `GpuBackend::batch_norm_f32` and match the CPU path
    /// (which is itself PyTorch-parity-verified by the conformance suite) to
    /// f32 tolerance. The test is not `#[ignore]`d, so it runs live on a CUDA
    /// host; if `init_cuda_backend()` fails, it returns early without asserting.
    #[test]
    #[cfg(feature = "cuda")]
    fn batch_norm2d_eval_forward_gpu_matches_cpu() {
        use crate::module::Module as _;
        use ferrotorch_core::Device;
        use ferrotorch_gpu::init_cuda_backend;
        if init_cuda_backend().is_err() {
            return;
        }

        let (b, c, h, w) = (2usize, 6usize, 4usize, 5usize);
        let n = b * c * h * w;
        let data: Vec<f32> = (0..n).map(|k| ((k % 19) as f32) * 0.11 - 1.0).collect();
        let gamma: Vec<f32> = (0..c).map(|k| 1.0 + 0.07 * (k as f32)).collect();
        let beta: Vec<f32> = (0..c).map(|k| -0.2 + 0.03 * (k as f32)).collect();
        let rmean: Vec<f32> = (0..c).map(|k| 0.05 * (k as f32) - 0.1).collect();
        let rvar: Vec<f32> = (0..c).map(|k| 0.8 + 0.05 * (k as f32)).collect();

        let make = || {
            let mut bn = BatchNorm2d::<f32>::new(c, 1e-5, 0.1, true).unwrap();
            bn.weight.as_mut().unwrap().set_data(
                Tensor::from_storage(TensorStorage::cpu(gamma.clone()), vec![c], false).unwrap(),
            );
            bn.bias.as_mut().unwrap().set_data(
                Tensor::from_storage(TensorStorage::cpu(beta.clone()), vec![c], false).unwrap(),
            );
            bn.set_running_mean(&rmean).unwrap();
            bn.set_running_var(&rvar).unwrap();
            bn.eval();
            bn
        };

        let x_cpu = Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![b, c, h, w], false)
            .unwrap();
        let bn_cpu = make();
        let cpu_vals = bn_cpu.forward(&x_cpu).unwrap().data().unwrap().to_vec();

        let mut bn_gpu = make();
        bn_gpu.to_device(Device::Cuda(0)).unwrap();
        let x_gpu = x_cpu.to(Device::Cuda(0)).unwrap();
        let y_gpu = bn_gpu.forward(&x_gpu).unwrap();
        assert!(y_gpu.is_cuda(), "BatchNorm2d GPU output must stay on CUDA");
        let gpu_vals = y_gpu.data_vec().unwrap();

        assert_eq!(gpu_vals.len(), cpu_vals.len());
        let mut max_abs = 0.0f32;
        for (g, cv) in gpu_vals.iter().zip(cpu_vals.iter()) {
            max_abs = max_abs.max((g - cv).abs());
        }
        assert!(
            max_abs < 1e-4,
            "BatchNorm2d eval GPU vs CPU max|Δ| = {max_abs}"
        );
    }
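
    // Sketch (std-only, hypothetical): eval-mode BatchNorm can be folded into
    // one per-channel scale and shift,
    //   scale = weight / sqrt(var + eps),  shift = bias - mean * scale,
    // so that y = x * scale + shift. This fold is a common way to write an
    // inference kernel; whether ferrotorch's `batch_norm_f32` uses it is not
    // stated here. Because the fold reorders floating-point operations, a
    // parity check needs a tolerance (like the 1e-4 above), not exact equality.
    fn sketch_fold_bn_eval(
        mean: &[f32],
        var: &[f32],
        weight: &[f32],
        bias: &[f32],
        eps: f32,
    ) -> (Vec<f32>, Vec<f32>) {
        let scale: Vec<f32> = (0..mean.len())
            .map(|ch| weight[ch] / (var[ch] + eps).sqrt())
            .collect();
        let shift: Vec<f32> = (0..mean.len())
            .map(|ch| bias[ch] - mean[ch] * scale[ch])
            .collect();
        (scale, shift)
    }

    #[test]
    fn sketch_folded_bn_matches_direct_formula() {
        let (mean, var) = ([0.1_f32, -0.3], [0.8_f32, 1.2]);
        let (weight, bias) = ([1.07_f32, 0.9], [-0.2_f32, 0.05]);
        let eps = 1e-5_f32;
        let (scale, shift) = sketch_fold_bn_eval(&mean, &var, &weight, &bias, eps);
        for ch in 0..2 {
            for &x in &[-1.0_f32, 0.0, 0.7, 2.5] {
                let direct = (x - mean[ch]) / (var[ch] + eps).sqrt() * weight[ch] + bias[ch];
                assert!((direct - (x * scale[ch] + shift[ch])).abs() < 1e-5);
            }
        }
    }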

    /// #1449: `BatchNorm2d::forward` in **train** mode on CUDA must compute the
    /// batch statistics on-device, match the CPU forward, AND fold the same
    /// running-mean / running-var update back (momentum + Bessel correction).
    #[test]
    #[cfg(feature = "cuda")]
    fn batch_norm2d_train_forward_gpu_matches_cpu() {
        use crate::module::Module as _;
        use ferrotorch_core::Device;
        use ferrotorch_gpu::init_cuda_backend;
        if init_cuda_backend().is_err() {
            return;
        }

        let (b, c, h, w) = (4usize, 5usize, 3usize, 3usize);
        let n = b * c * h * w;
        let data: Vec<f32> = (0..n).map(|k| ((k as f32) * 0.037).sin() * 1.3).collect();
        let gamma: Vec<f32> = (0..c).map(|k| 0.9 + 0.04 * (k as f32)).collect();
        let beta: Vec<f32> = (0..c).map(|k| 0.02 * (k as f32)).collect();

        let make = || {
            let mut bn = BatchNorm2d::<f32>::new(c, 1e-5, 0.1, true).unwrap();
            bn.weight.as_mut().unwrap().set_data(
                Tensor::from_storage(TensorStorage::cpu(gamma.clone()), vec![c], false).unwrap(),
            );
            bn.bias.as_mut().unwrap().set_data(
                Tensor::from_storage(TensorStorage::cpu(beta.clone()), vec![c], false).unwrap(),
            );
            bn // default training=true
        };

        let x_cpu = Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![b, c, h, w], false)
            .unwrap();
        let bn_cpu = make();
        let cpu_vals = bn_cpu.forward(&x_cpu).unwrap().data().unwrap().to_vec();
        let cpu_rmean = bn_cpu.running_mean();
        let cpu_rvar = bn_cpu.running_var();

        let mut bn_gpu = make();
        bn_gpu.to_device(Device::Cuda(0)).unwrap();
        let x_gpu = x_cpu.to(Device::Cuda(0)).unwrap();
        let y_gpu = bn_gpu.forward(&x_gpu).unwrap();
        assert!(y_gpu.is_cuda());
        let gpu_vals = y_gpu.data_vec().unwrap();

        let mut max_abs = 0.0f32;
        for (g, cv) in gpu_vals.iter().zip(cpu_vals.iter()) {
            max_abs = max_abs.max((g - cv).abs());
        }
        assert!(
            max_abs < 1e-4,
            "BatchNorm2d train GPU vs CPU max|Δ| = {max_abs}"
        );

        // Running-stat update must match the CPU path and increment the counter.
        let gpu_rmean = bn_gpu.running_mean();
        let gpu_rvar = bn_gpu.running_var();
        assert_eq!(bn_gpu.num_batches_tracked(), 1);
        for cc in 0..c {
            assert!(
                (gpu_rmean[cc] - cpu_rmean[cc]).abs() < 1e-4,
                "running_mean[{cc}] gpu={} cpu={}",
                gpu_rmean[cc],
                cpu_rmean[cc]
            );
            assert!(
                (gpu_rvar[cc] - cpu_rvar[cc]).abs() < 1e-4,
                "running_var[{cc}] gpu={} cpu={}",
                gpu_rvar[cc],
                cpu_rvar[cc]
            );
        }
    }
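
    // Sketch (std-only, hypothetical helper) of the running-stat update the
    // train test compares, following PyTorch's documented BatchNorm
    // convention: `momentum` weights the new batch statistic, and running_var
    // receives the unbiased (Bessel-corrected, divide by n - 1) batch variance,
    // while the forward itself normalizes with the biased variance. Here n is
    // the number of values per channel, N * H * W. This is not ferrotorch code.
    fn sketch_update_running_stats(
        running_mean: &mut f64,
        running_var: &mut f64,
        channel_values: &[f64],
        momentum: f64,
    ) {
        let n = channel_values.len() as f64;
        assert!(n > 1.0, "Bessel correction needs at least two values");
        let mean = channel_values.iter().sum::<f64>() / n;
        let biased_var = channel_values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
        let unbiased_var = biased_var * n / (n - 1.0);
        *running_mean = (1.0 - momentum) * *running_mean + momentum * mean;
        *running_var = (1.0 - momentum) * *running_var + momentum * unbiased_var;
    }

    #[test]
    fn sketch_running_stat_update_example() {
        // Start from the construction defaults (mean = 0, var = 1).
        let (mut rm, mut rv) = (0.0_f64, 1.0_f64);
        // Values 1..=4: mean 2.5, biased var 1.25, unbiased var 5/3.
        sketch_update_running_stats(&mut rm, &mut rv, &[1.0, 2.0, 3.0, 4.0], 0.1);
        assert!((rm - 0.25).abs() < 1e-12);
        assert!((rv - (0.9 + 0.1 * 5.0 / 3.0)).abs() < 1e-12);
    }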

    /// #1449: `InstanceNorm2d::forward` on CUDA routes through the GroupNorm
    /// kernel with `num_groups == num_channels` and must match the CPU path.
    #[test]
    #[cfg(feature = "cuda")]
    fn instance_norm2d_forward_gpu_matches_cpu() {
        use crate::module::Module as _;
        use ferrotorch_core::Device;
        use ferrotorch_gpu::init_cuda_backend;
        if init_cuda_backend().is_err() {
            return;
        }

        let (b, c, h, w) = (3usize, 4usize, 5usize, 4usize);
        let n = b * c * h * w;
        let data: Vec<f32> = (0..n).map(|k| ((k % 23) as f32) * 0.09 - 0.8).collect();
        let gamma: Vec<f32> = (0..c).map(|k| 1.0 + 0.06 * (k as f32)).collect();
        let beta: Vec<f32> = (0..c).map(|k| -0.05 + 0.04 * (k as f32)).collect();

        let make = || {
            let mut inorm = InstanceNorm2d::<f32>::new(c, 1e-5, true).unwrap();
            inorm.inner.weight.set_data(
                Tensor::from_storage(TensorStorage::cpu(gamma.clone()), vec![c], false).unwrap(),
            );
            inorm.inner.bias.set_data(
                Tensor::from_storage(TensorStorage::cpu(beta.clone()), vec![c], false).unwrap(),
            );
            inorm
        };

        let x_cpu = Tensor::from_storage(TensorStorage::cpu(data.clone()), vec![b, c, h, w], false)
            .unwrap();
        let cpu_vals = make().forward(&x_cpu).unwrap().data().unwrap().to_vec();

        let mut gpu = make();
        gpu.to_device(Device::Cuda(0)).unwrap();
        let x_gpu = x_cpu.to(Device::Cuda(0)).unwrap();
        let y_gpu = gpu.forward(&x_gpu).unwrap();
        assert!(
            y_gpu.is_cuda(),
            "InstanceNorm2d GPU output must stay on CUDA"
        );
        let gpu_vals = y_gpu.data_vec().unwrap();

        let mut max_abs = 0.0f32;
        for (g, cv) in gpu_vals.iter().zip(cpu_vals.iter()) {
            max_abs = max_abs.max((g - cv).abs());
        }
        assert!(
            max_abs < 1e-4,
            "InstanceNorm2d GPU vs CPU max|Δ| = {max_abs}"
        );
    }
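
    // Sketch (std-only, hypothetical): why routing InstanceNorm2d through the
    // GroupNorm kernel with `num_groups == num_channels` is exact. With one
    // channel per group, each group slice is a single (batch, channel) plane
    // of S = H * W values, which is precisely the set InstanceNorm normalizes.
    // The helpers compare both views on a contiguous [B, C, S] buffer with an
    // identity affine; the layout and names are illustrative assumptions.
    fn sketch_normalize_chunks(x: &[f32], chunk: usize, eps: f32) -> Vec<f32> {
        let mut out = Vec::with_capacity(x.len());
        for slice in x.chunks(chunk) {
            let n = slice.len() as f32;
            let mean = slice.iter().sum::<f32>() / n;
            let var = slice.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n;
            let inv_std = 1.0 / (var + eps).sqrt();
            out.extend(slice.iter().map(|v| (v - mean) * inv_std));
        }
        out
    }

    /// GroupNorm view: each (batch, group) slice spans (C / G) * S values.
    fn sketch_group_view(x: &[f32], c: usize, s: usize, groups: usize, eps: f32) -> Vec<f32> {
        assert!(groups > 0 && c % groups == 0, "C must be divisible by G");
        sketch_normalize_chunks(x, (c / groups) * s, eps)
    }

    /// InstanceNorm view: each (batch, channel) plane spans S values.
    fn sketch_instance_view(x: &[f32], s: usize, eps: f32) -> Vec<f32> {
        sketch_normalize_chunks(x, s, eps)
    }

    #[test]
    fn sketch_instance_norm_is_group_norm_with_one_channel_per_group() {
        let (b, c, s) = (2usize, 3usize, 4usize);
        let x: Vec<f32> = (0..b * c * s).map(|k| ((k % 7) as f32) * 0.2 - 0.5).collect();
        let inorm = sketch_instance_view(&x, s, 1e-5);
        assert_eq!(sketch_group_view(&x, c, s, c, 1e-5), inorm);
        // A coarser grouping (G = 1) normalizes over all channels and differs.
        assert_ne!(sketch_group_view(&x, c, s, 1, 1e-5), inorm);
    }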
}