ferrotorch_nn/
conv.rs

//! Convolution layers: 1-D, 2-D, 3-D and their transposed variants.
//!
//! Implements `Conv1d<T>`, `Conv2d<T>`, `Conv3d<T>`, `ConvTranspose1d<T>`,
//! `ConvTranspose2d<T>`, and `ConvTranspose3d<T>`.
//! Forward passes use the im2col + matmul approach; backward follows the
//! same structure in reverse.
//!
//! ## REQ status (per `.design/ferrotorch-nn/conv.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | impl: `pub struct Conv2d<T: Float>` here, mirroring `aten/src/ATen/native/Convolution.cpp:520-600`; non-test consumer: `ferrotorch-vision/src/models/resnet.rs` constructs `Conv2d::new(...)` for every residual block conv. |
//! | REQ-2 | SHIPPED | impl: the `Conv2d::new` / `Conv2d::new_full` constructors here with `groups` / `dilation` validation; non-test consumer: `ferrotorch-vision/src/models/vit.rs` and `convnext.rs` construct grouped or dilated `Conv2d` via `new_full`. |
//! | REQ-3 | SHIPPED | impl: `<Conv2d as Module>::forward` body here (im2col + matmul) mirroring `aten::convolution`; non-test consumer: every vision model forward invokes `Conv2d::forward` through its `Module` impl. |
//! | REQ-4 | SHIPPED | impl: `is_f32 && input.is_cuda()` dispatch to `backend.conv2d_f32` inside `<Conv2d as Module>::forward`; non-test consumer: `ferrotorch-gpu/src/backend_impl.rs` exposes `Backend::conv2d_f32`; vision-model training runs on CUDA trigger this dispatch end-to-end. |
//! | REQ-5 | SHIPPED | impl: `Conv2dBackward<T>: GradFn<T>` impl block here; non-test consumer: every gradient step on a vision model's `loss.backward()` traverses these `Conv2dBackward` nodes through `ferrotorch_core::autograd::engine`. |
//! | REQ-6 | SHIPPED | impl: `pub struct Conv1d` / `Conv3d` / `ConvTranspose{1,2,3}d` here, each carrying `groups`/`dilation` via `*::new_full(.., dilation, groups, bias)`. Forward layers: per-group + dilated `<Conv1d as Module>::forward` / `<Conv3d as Module>::forward` + `Conv1dBackward` / `Conv3dBackward` (closes #1600 conv1d, #1601 conv3d). Transposed layers: `ConvTranspose{1,2,3}d::new_full` + the per-group helpers `conv_transpose2d_forward_group` / `conv_transpose3d_forward_group` + per-group loops in `<ConvTranspose*d as Module>::forward` + per-group/dilated `ConvTranspose{1,2,3}dBackward` (closes #1607 groups, #1608 dilation), plus the rank-`D+1` unbatched `unsqueeze`/recurse/`squeeze` guard atop each transposed `forward` (closes #1609). Transposed weight `[in_channels, out_channels/groups, *k]` per `torch/nn/modules/conv.py:164`; channel partition (input dim 1 / weight dim 0 / bias dim 0) per `aten/src/ATen/native/Convolution.cpp:1723-1729`; dilated internal conv `internal_pad = dilation*(k-1) - padding`, `eff_kernel = dilation*(k-1)+1` per `aten/src/ATen/native/ConvUtils.h:255`. non-test consumer: `Conv1d::new` / `Conv3d::new` / `ConvTranspose{1,2,3}d::new` delegate to `new_full` in production and are called by `ferrotorch-nn/src/lazy_conv.rs` `LazyConv{1,3}d::materialize` / `ferrotorch-nn/src/lazy_conv_transpose.rs` `LazyConvTranspose{1,2,3}d::materialize`; `ferrotorch-vision/src/models/detection/{mask_rcnn,keypoint_rcnn}.rs` construct `ConvTranspose2d::new`; `ferrotorch-vision/src/models/inception.rs` uses `Conv2d` + `ConvTranspose2d`. |
//! | REQ-7 | SHIPPED | impl: `impl<T: Float> Module<T> for Conv2d<T>` block (and analogues for the other 5) here; non-test consumer: `ferrotorch_optim` walks `Module::parameters_mut()` across every conv in a training loop. |
//! | REQ-8 | SHIPPED | impl: the `Conv2d::set_weight` and `Conv2d::from_parts` methods here; non-test consumer: `ferrotorch-nn/src/functional.rs` (the stateless `nn::functional::conv2d` entry point) uses `Conv2d::from_parts` to drive the existing forward path with user-supplied parameters. |
//! | REQ-9 | SHIPPED | impl: `kaiming_uniform(&mut weight, NonLinearity::ReLU)` + `uniform_init(&mut b, -bound, bound)` (bound = 1/sqrt(fan_in)) inside every `Conv*d::new[_full]` here, mirroring `torch/nn/modules/conv.py:182-201`; non-test consumer: `Conv2d::new` is the path used by every vision-model constructor. (closes #1450: bias U(-bound,bound). Kaiming gain divergence (`a=sqrt(5)` upstream vs `ReLU` here) remains as separate followup.) |
//! | REQ-10 | SHIPPED | impl: `Conv1d` / `Conv2d` / `Conv3d` each carry a `padding_mode: crate::padding::PaddingMode` field + `with_padding_mode(...)` builder here; when the mode is non-`Zeros`, the layer's `forward` calls `crate::padding::functional_pad_1d`/`_2d`/`_3d` (with `_reversed_padding_repeated_twice` amounts) and then runs the zero-padding im2col on the already-padded tensor, mirroring `torch/nn/modules/conv.py` `_ConvNd._conv_forward` (Conv1d `conv.py:367-378`, Conv3d `conv.py:716-732`). The 1-D/3-D pre-pads are autograd-aware (`Pad1dBackward` / `Pad3dBackward` in `padding.rs`), so input gradients flow through the boundary (the #1550 fix class). `ConvTranspose{1,2,3}d::with_padding_mode` rejects any non-`Zeros` mode via `fn reject_non_zeros_transpose` with the upstream `ValueError('Only "zeros" padding mode is supported for ...')` (`conv.py:755-758`). Closes #1443. Non-test consumer: `pub use conv::{Conv1d, Conv2d, Conv3d, ConvTranspose1d, ConvTranspose2d, ConvTranspose3d}` in `lib.rs` re-export; the `<Conv1d as Module>::forward` / `<Conv3d as Module>::forward` bodies consume `functional_pad_1d` / `functional_pad_3d` in production. |
//! | REQ-11 | NOT-STARTED | blocker #1441 (umbrella): parity-sweep runner arms for all 6 conv ops are absent; sweep reports `0/N passed, N skipped` for each. The forward paths themselves are end-to-end verified by 60+ lib tests; only the runner-arm wiring is missing. |
//! | REQ-12 | SHIPPED | impl: `pub enum StringPadding` + `fn same_pad_lr` + `Conv{1,2,3}d::with_string_padding` and the `string_padding` branch of each `<Conv*d as Module>::forward` here (asymmetric `'same'` pre-pad via `crate::padding::functional_pad_{1,2,3}d`, `left=total/2`/`right=total-left` per `aten/src/ATen/native/Pool.h:91-107`; `'valid'`==padding 0 per `aten/src/ATen/native/Convolution.cpp:1122-1124`; stride>1 `'same'` rejected per `torch/nn/modules/conv.py:117-120`). Non-test consumer: the production `Module::forward` bodies consume `same_pad_lr` + `functional_pad_{1,2,3}d` + `recurse_clone`; `Conv{1,2,3}d` re-exported from `lib.rs`. Closes #1602. |
//! | REQ-13 | SHIPPED | impl: the unbatched `input.ndim()` guard at the top of each `<Conv*d as Module>::forward` here (`unsqueeze(0)` → recurse → `squeeze(0)` via `ferrotorch_core::grad_fns::shape::{unsqueeze, squeeze}`), mirroring `batchify` + `output.squeeze(0)` at `aten/src/ATen/native/Convolution.cpp:816-831, 990-1047`. Non-test consumer: the production `Module::forward` bodies call `unsqueeze`/`squeeze`; `Conv{1,2,3}d` re-exported from `lib.rs`. Closes #1604. |

use std::sync::Arc;

use ferrotorch_core::autograd::autocast_ops::autocast_guard;
use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::grad_fns::shape::{squeeze, unsqueeze};
use ferrotorch_core::ops::linalg::{mm, transpose};
use ferrotorch_core::storage::TensorStorage;
use ferrotorch_core::tensor::{GradFn, Tensor};
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};

use crate::init::{NonLinearity, kaiming_uniform, uniform as uniform_init};
use crate::module::Module;
use crate::parameter::Parameter;

// ---------------------------------------------------------------------------
// ConvTranspose padding_mode validation
// ---------------------------------------------------------------------------

/// Reject any non-`Zeros` `padding_mode` for a transposed convolution.
///
/// Upstream `_ConvTransposeNd.__init__` (`torch/nn/modules/conv.py:755-758`)
/// runs `if padding_mode != "zeros": raise ValueError(f'Only "zeros" padding
/// mode is supported for {self.__class__.__name__}')`. Only `"zeros"` is a
/// valid `padding_mode` for ConvTranspose layers; matching this exception
/// behaviour (rather than silently accepting the mode) is the R-DEV-2 contract.
/// Closes #1443.
fn reject_non_zeros_transpose(
    mode: crate::padding::PaddingMode,
    class_name: &str,
) -> FerrotorchResult<()> {
    if mode != crate::padding::PaddingMode::Zeros {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("Only \"zeros\" padding mode is supported for {class_name}"),
        });
    }
    Ok(())
}

// ---------------------------------------------------------------------------
// String padding ('same' / 'valid'), #1602
// ---------------------------------------------------------------------------

/// The string-padding modes a `Conv{1,2,3}d` may be configured with, mirroring
/// the `padding: str` branch of `torch.nn.Conv{1,2,3}d`
/// (`torch/nn/modules/conv.py:111-120`, `valid_padding_strings = {"same",
/// "valid"}`).
///
/// - [`StringPadding::Valid`] is equivalent to `padding = 0` (no padding;
///   `aten/src/ATen/native/Convolution.cpp:1122-1124`
///   `padding == "valid" -> convolution_symint(.., {{0}}, ..)`).
/// - [`StringPadding::Same`] pads so that, for `stride = 1`, the output spatial
///   size equals the input spatial size. The total pad per dim is
///   `dilation * (kernel - 1)`, split ASYMMETRICALLY as
///   `left = total / 2`, `right = total - left` (the END gets the extra pad
///   when `total` is odd), mirroring `_pooling_same_mode_padding_lr`
///   (`aten/src/ATen/native/Pool.h:91-107`) and the matching
///   `_ConvNd.__init__` `_reversed_padding_repeated_twice` arithmetic
///   (`conv.py:143-155`). `'same'` is rejected for strided convolutions
///   (`conv.py:117-120` / `Convolution.cpp:1071`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StringPadding {
    /// `padding='same'`: pad so output spatial size == input spatial size
    /// (stride must be 1). Asymmetric split per [`same_pad_lr`].
    Same,
    /// `padding='valid'`: no padding (equivalent to `padding = 0`).
    Valid,
}

/// Compute the asymmetric `(left, right)` `'same'` padding for one spatial dim.
///
/// Mirrors `_pooling_same_mode_padding_lr` in
/// `aten/src/ATen/native/Pool.h:91-107`:
///
/// ```text
/// total_padding = dilation * (kernel_size - 1)
/// left  = total_padding / 2          // floor
/// right = total_padding - left
/// ```
///
/// The `stride > 2` wiggle-room branch of the upstream helper is unreachable
/// here because `'same'` requires `stride == 1` (validated at construction,
/// `conv.py:117-120`). The right side therefore receives the extra unit of
/// padding whenever `total_padding` is odd; this is the exact same arithmetic
/// the Python `_ConvNd.__init__` runs to populate
/// `_reversed_padding_repeated_twice` for the `'same'` case
/// (`conv.py:150-155`).
fn same_pad_lr(kernel_size: usize, dilation: usize) -> (usize, usize) {
    let total_padding = dilation * (kernel_size - 1);
    let left = total_padding / 2;
    (left, total_padding - left)
}

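/*
Worked example for the `'same'` split documented on `same_pad_lr` above. This is
a minimal sketch, not crate code: `same_pad_split`, `out_len_stride1`, and
`demo_same_pad` are hypothetical names that restate the doc-comment arithmetic
so it can be checked in isolation, assuming `kernel_size >= 1`.

```rust
// Hypothetical standalone copy of the split arithmetic, for illustration only.
fn same_pad_split(kernel_size: usize, dilation: usize) -> (usize, usize) {
    let total = dilation * (kernel_size - 1);
    let left = total / 2; // floor division
    (left, total - left) // the right side gets the extra unit when `total` is odd
}

// Output length of a stride-1 convolution over an input padded by (left, right).
fn out_len_stride1(input: usize, kernel_size: usize, dilation: usize) -> usize {
    let (left, right) = same_pad_split(kernel_size, dilation);
    let eff_kernel = dilation * (kernel_size - 1) + 1;
    input + left + right - eff_kernel + 1
}

fn demo_same_pad() {
    // Odd kernel: total = 2, symmetric split.
    assert_eq!(same_pad_split(3, 1), (1, 1));
    // Even kernel: total = 3, so the right side gets the extra unit.
    assert_eq!(same_pad_split(4, 1), (1, 2));
    // Dilated odd kernel: total = 2 * 2 = 4.
    assert_eq!(same_pad_split(3, 2), (2, 2));
    // With stride 1 the padded convolution keeps the spatial length.
    assert_eq!(out_len_stride1(10, 4, 1), 10);
    assert_eq!(out_len_stride1(10, 3, 2), 10);
}
```

The even-kernel case is where the asymmetry matters: a symmetric pad of 1 or 2
on both sides would give an output of length 9 or 11 instead of 10.
*/
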
// ---------------------------------------------------------------------------
// im2col / col2im helpers
// ---------------------------------------------------------------------------

/// Extract image patches into columns (no dilation: calls [`im2col_dilated`]
/// with `(1, 1)` for the dilation rate).
///
/// Given a 4-D input `[B, C, H, W]`, produces a 3-D output
/// `[B, C * kH * kW, H_out * W_out]` where each column is one
/// flattened receptive-field patch.
// Internal kernel: argument set mirrors the 2-D convolution descriptor
// (B, C, H, W, kH, kW, strideH, strideW, padH, padW); a config
// struct would force allocation on every call in convolution hot paths.
#[allow(clippy::too_many_arguments)]
fn im2col<T: Float>(
    input: &[T],
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
) -> (Vec<T>, usize, usize) {
    im2col_dilated(
        input, batch, channels, height, width, kernel_h, kernel_w, stride_h, stride_w, pad_h,
        pad_w, 1, 1,
    )
}

/// Extract image patches into columns, supporting dilation `(dil_h, dil_w)`.
///
/// Given a 4-D input `[B, C, H, W]`, produces a 3-D output
/// `[B, C * kH * kW, H_out * W_out]` where each column is one flattened
/// receptive-field patch with kernel taps spaced by `dil_h`/`dil_w` along the
/// spatial axes.
///
/// Output spatial sizes follow the standard formula:
///
/// ```text
/// H_out = (H + 2*pad_h - dil_h*(kH - 1) - 1) / stride_h + 1
/// W_out = (W + 2*pad_w - dil_w*(kW - 1) - 1) / stride_w + 1
/// ```
// Internal kernel: argument set mirrors the 2-D convolution descriptor
// (B, C, H, W, kH, kW, strideH, strideW, padH, padW, dilH, dilW); a config
// struct would force allocation on every call in convolution hot paths.
#[allow(clippy::too_many_arguments)]
fn im2col_dilated<T: Float>(
    input: &[T],
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
    dil_h: usize,
    dil_w: usize,
) -> (Vec<T>, usize, usize) {
    let eff_kh = dil_h * (kernel_h - 1) + 1;
    let eff_kw = dil_w * (kernel_w - 1) + 1;
    let h_out = (height + 2 * pad_h - eff_kh) / stride_h + 1;
    let w_out = (width + 2 * pad_w - eff_kw) / stride_w + 1;
    let col_rows = channels * kernel_h * kernel_w;
    let col_cols = h_out * w_out;

    let zero = <T as num_traits::Zero>::zero();
    let mut cols = vec![zero; batch * col_rows * col_cols];

    for b in 0..batch {
        for c in 0..channels {
            for kh in 0..kernel_h {
                for kw in 0..kernel_w {
                    let row = c * kernel_h * kernel_w + kh * kernel_w + kw;
                    for oh in 0..h_out {
                        for ow in 0..w_out {
                            // The padded-coordinate of this kernel tap.
                            let ih = oh * stride_h + kh * dil_h;
                            let iw = ow * stride_w + kw * dil_w;
                            let col = oh * w_out + ow;

                            // Account for padding: the "virtual" input coordinate
                            // must be shifted back by the padding amount.
                            let val = if ih >= pad_h
                                && iw >= pad_w
                                && (ih - pad_h) < height
                                && (iw - pad_w) < width
                            {
                                let real_h = ih - pad_h;
                                let real_w = iw - pad_w;
                                input[b * channels * height * width
                                    + c * height * width
                                    + real_h * width
                                    + real_w]
                            } else {
                                zero
                            };

                            cols[b * col_rows * col_cols + row * col_cols + col] = val;
                        }
                    }
                }
            }
        }
    }

    (cols, col_rows, col_cols)
}

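/*
To make the column layout produced by `im2col_dilated` concrete, here is a
reduced sketch of the same gather for a single image with no batch axis.
`im2col_sketch` and `demo_im2col` are hypothetical names, the element type is
fixed to `f32`, and the tuple-style arguments are an assumption made for
brevity; the index arithmetic follows the loop above.

```rust
// Simplified, hypothetical im2col for one image `[C, H, W]`; returns
// (cols, rows, n_cols) with cols laid out as `[C*kH*kW, H_out*W_out]`.
fn im2col_sketch(
    input: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    kernel: (usize, usize),
    stride: (usize, usize),
    pad: (usize, usize),
    dil: (usize, usize),
) -> (Vec<f32>, usize, usize) {
    let (kh_n, kw_n) = kernel;
    let h_out = (height + 2 * pad.0 - dil.0 * (kh_n - 1) - 1) / stride.0 + 1;
    let w_out = (width + 2 * pad.1 - dil.1 * (kw_n - 1) - 1) / stride.1 + 1;
    let (rows, n_cols) = (channels * kh_n * kw_n, h_out * w_out);
    let mut cols = vec![0.0f32; rows * n_cols];
    for c in 0..channels {
        for kh in 0..kh_n {
            for kw in 0..kw_n {
                let row = (c * kh_n + kh) * kw_n + kw;
                for oh in 0..h_out {
                    for ow in 0..w_out {
                        // Coordinates in the zero-padded image.
                        let ih = oh * stride.0 + kh * dil.0;
                        let iw = ow * stride.1 + kw * dil.1;
                        // Taps that fall in the padding border stay zero.
                        let inside = ih >= pad.0
                            && iw >= pad.1
                            && ih - pad.0 < height
                            && iw - pad.1 < width;
                        if inside {
                            cols[row * n_cols + oh * w_out + ow] =
                                input[(c * height + ih - pad.0) * width + iw - pad.1];
                        }
                    }
                }
            }
        }
    }
    (cols, rows, n_cols)
}

fn demo_im2col() {
    // 1 channel, 3x3 image holding 1..=9, 2x2 kernel, stride 1, no padding.
    let input: Vec<f32> = (1..=9).map(|v| v as f32).collect();
    let (cols, rows, n_cols) = im2col_sketch(&input, 1, 3, 3, (2, 2), (1, 1), (0, 0), (1, 1));
    assert_eq!((rows, n_cols), (4, 4));
    // Column 0, read down the rows, is the top-left patch [1, 2, 4, 5].
    let patch0: Vec<f32> = (0..rows).map(|r| cols[r * n_cols]).collect();
    assert_eq!(patch0, vec![1.0, 2.0, 4.0, 5.0]);
}
```

Each column is one flattened receptive field, so multiplying a
`[C_out, C*kH*kW]` weight matrix by `cols` yields every output position at once.
*/
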
/// Scatter columns back into an image tensor (adjoint of [`im2col`]).
///
/// Given columns of shape `[B, C * kH * kW, H_out * W_out]`, accumulates
/// values back into a `[B, C, H, W]` tensor (with padding stripped).
///
/// `#[cfg(test)]`-gated: production backward paths (`Conv1dBackward`,
/// `Conv2dBackward`) call [`col2im_dilated`] directly with the layer's
/// dilation, so the only remaining caller of this non-dilated shim is the
/// im2col/col2im roundtrip unit test.
// Internal kernel: argument set is the adjoint of `im2col` (same descriptor
// inputs); refactoring to a config struct would diverge the two helpers'
// signatures unhelpfully.
#[cfg(test)]
#[allow(clippy::too_many_arguments)]
fn col2im<T: Float>(
    cols: &[T],
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
    h_out: usize,
    w_out: usize,
) -> Vec<T> {
    col2im_dilated(
        cols, batch, channels, height, width, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
        1, 1, h_out, w_out,
    )
}

/// Scatter columns back into an image tensor with dilation support
/// (adjoint of [`im2col_dilated`]).
///
/// Given columns of shape `[B, C * kH * kW, H_out * W_out]`, accumulates
/// values back into a `[B, C, H, W]` tensor (with padding stripped),
/// honouring `dil_h`/`dil_w` so kernel taps are spaced apart in the input.
// Internal kernel: argument set is the adjoint of `im2col_dilated` (same
// descriptor inputs); refactoring to a config struct would diverge the two
// helpers' signatures unhelpfully.
#[allow(clippy::too_many_arguments)]
fn col2im_dilated<T: Float>(
    cols: &[T],
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_h: usize,
    stride_w: usize,
    pad_h: usize,
    pad_w: usize,
    dil_h: usize,
    dil_w: usize,
    h_out: usize,
    w_out: usize,
) -> Vec<T> {
    let zero = <T as num_traits::Zero>::zero();
    let mut output = vec![zero; batch * channels * height * width];

    let col_rows = channels * kernel_h * kernel_w;
    let col_cols = h_out * w_out;

    for b in 0..batch {
        for c in 0..channels {
            for kh in 0..kernel_h {
                for kw in 0..kernel_w {
                    let row = c * kernel_h * kernel_w + kh * kernel_w + kw;
                    for oh in 0..h_out {
                        for ow in 0..w_out {
                            let ih = oh * stride_h + kh * dil_h;
                            let iw = ow * stride_w + kw * dil_w;
                            let col = oh * w_out + ow;

                            if ih >= pad_h
                                && iw >= pad_w
                                && (ih - pad_h) < height
                                && (iw - pad_w) < width
                            {
                                let real_h = ih - pad_h;
                                let real_w = iw - pad_w;
                                output[b * channels * height * width
                                    + c * height * width
                                    + real_h * width
                                    + real_w] +=
                                    cols[b * col_rows * col_cols + row * col_cols + col];
                            }
                        }
                    }
                }
            }
        }
    }

    output
}

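/*
The doc comments above call `col2im_dilated` the adjoint of `im2col_dilated`.
One way to check that claim is the inner-product identity
`<im2col(x), y> == <x, col2im(y)>`. The sketch below does so for hypothetical
1-D analogues (`im2col_1d`, `col2im_1d`, single batch, single channel); these
helpers are illustrative stand-ins, not functions of this crate.

```rust
// Output length for one spatial dim (same formula as H_out / W_out above).
fn out_len(len: usize, k: usize, stride: usize, pad: usize, dil: usize) -> usize {
    (len + 2 * pad - dil * (k - 1) - 1) / stride + 1
}

// Gather: cols[kk * l_out + o] = x[o * stride + kk * dil - pad], or 0 in the border.
fn im2col_1d(x: &[f64], k: usize, stride: usize, pad: usize, dil: usize) -> Vec<f64> {
    let l_out = out_len(x.len(), k, stride, pad, dil);
    let mut cols = vec![0.0; k * l_out];
    for kk in 0..k {
        for o in 0..l_out {
            let i = o * stride + kk * dil;
            if i >= pad && i - pad < x.len() {
                cols[kk * l_out + o] = x[i - pad];
            }
        }
    }
    cols
}

// Scatter-add: the same index map as the gather, with `+=` into the image.
fn col2im_1d(
    cols: &[f64],
    len: usize,
    k: usize,
    stride: usize,
    pad: usize,
    dil: usize,
) -> Vec<f64> {
    let l_out = out_len(len, k, stride, pad, dil);
    let mut x = vec![0.0; len];
    for kk in 0..k {
        for o in 0..l_out {
            let i = o * stride + kk * dil;
            if i >= pad && i - pad < len {
                x[i - pad] += cols[kk * l_out + o];
            }
        }
    }
    x
}

fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(p, q)| p * q).sum()
}

fn demo_adjoint() {
    let (k, stride, pad, dil) = (3, 2, 1, 2);
    let x: Vec<f64> = (1..=7).map(f64::from).collect();
    let cols_len = k * out_len(x.len(), k, stride, pad, dil);
    let y: Vec<f64> = (0..cols_len).map(|v| (v % 5) as f64 - 2.0).collect();
    // Exact equality holds here because every value is a small integer.
    let lhs = dot(&im2col_1d(&x, k, stride, pad, dil), &y);
    let rhs = dot(&x, &col2im_1d(&y, x.len(), k, stride, pad, dil));
    assert_eq!(lhs, rhs);
}
```

Because patches overlap, the scatter accumulates: applying `col2im_1d` to
`im2col_1d(x)` generally returns each element multiplied by the number of
patches that cover it, not `x` itself.
*/
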
// ---------------------------------------------------------------------------
// Conv2d
// ---------------------------------------------------------------------------

/// A 2-D convolution layer.
///
/// Applies a spatial convolution over an input `[B, C_in, H, W]` using
/// the im2col + matmul algorithm. Equivalent to `torch.nn.Conv2d`,
/// including the `groups` and `dilation` arguments (see
/// [`Conv2d::new_full`]).
///
/// # Shape
///
/// - Input: `[B, in_channels, H, W]`
/// - Output: `[B, out_channels, H_out, W_out]`
///
/// where `H_out = (H + 2 * padding.0 - dilation.0 * (kernel_size.0 - 1) - 1)
/// / stride.0 + 1`.
///
/// # GPU coverage
///
/// The CUDA fast path supplied by `ferrotorch-gpu` currently only covers
/// `groups == 1 && dilation == (1, 1)`. When the layer is constructed with
/// `groups > 1` or `dilation != (1, 1)`, [`Module::forward`] explicitly
/// skips the GPU dispatch at the gate (see the `if groups == 1 && dilation
/// == (1, 1)` guard in the forward) and runs the entire convolution on the
/// CPU. Wiring `groups`/`dilation` through the GPU backend signature is
/// tracked separately as a backend extension issue.
#[derive(Debug)]
pub struct Conv2d<T: Float> {
    /// Learnable kernel weights `[out_channels, in_channels / groups, kH, kW]`.
    weight: Parameter<T>,
    /// Optional learnable bias `[out_channels]`.
    bias: Option<Parameter<T>>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels (filters).
    out_channels: usize,
    /// Kernel spatial size `(kH, kW)`.
    kernel_size: (usize, usize),
    /// Stride `(sH, sW)`.
    stride: (usize, usize),
    /// Zero-padding `(pH, pW)` applied to both sides.
    padding: (usize, usize),
    /// Dilation `(dilH, dilW)`. `(1, 1)` is the dense default.
    dilation: (usize, usize),
    /// Number of blocked input/output channel groups. `1` is dense, `in_channels`
    /// is depthwise. Must divide both `in_channels` and `out_channels`.
    groups: usize,
    /// Boundary handling for the spatial padding. `Zeros` (default) routes
    /// through the existing im2col fast path; non-`Zeros` modes pre-pad
    /// the input via `crate::padding::functional_pad_2d` and then run the
    /// dense im2col over the already-padded tensor (matching the upstream
    /// `_ConvNd._conv_forward` shape: `F.pad(input, ..., mode=...)` first,
    /// then a `padding=0` convolution). Closes #1443.
    padding_mode: crate::padding::PaddingMode,
    /// String padding mode (`'same'` / `'valid'`), `None` when numeric
    /// `padding` is used. When `Some`, the `padding` field is ignored and the
    /// effective padding is derived per [`StringPadding`] in `forward`
    /// (mirroring the `padding: str` branch of `torch.nn.Conv2d`,
    /// `torch/nn/modules/conv.py:111-155`). Set via
    /// [`Conv2d::with_string_padding`]. Closes #1602.
    string_padding: Option<StringPadding>,
    /// Whether the module is in training mode.
    training: bool,
}

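/*
The `# Shape` formula in the `Conv2d` docs can be checked with a few numbers.
A minimal sketch, assuming a hypothetical free function `conv_out_dim` (not
part of this crate) and `kernel >= 1`, `stride >= 1`. The `checked_sub` guard
for inputs smaller than the kernel extent is an addition of this sketch.

```rust
// Hypothetical helper: output size along one spatial dim, or `None` when the
// padded input is smaller than the effective (dilated) kernel extent.
fn conv_out_dim(
    input: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> Option<usize> {
    let eff_kernel = dilation * (kernel - 1) + 1;
    (input + 2 * padding)
        .checked_sub(eff_kernel)
        .map(|span| span / stride + 1)
}

fn demo_out_dim() {
    // 7-wide kernel, stride 2, padding 3 on a 224-pixel side: 223 / 2 + 1.
    assert_eq!(conv_out_dim(224, 7, 2, 3, 1), Some(112));
    // 3-wide kernel with dilation 2 and padding 2 preserves a 32-pixel side.
    assert_eq!(conv_out_dim(32, 3, 1, 2, 2), Some(32));
    // Effective kernel extent 5 does not fit in a 4-pixel unpadded side.
    assert_eq!(conv_out_dim(4, 3, 1, 0, 2), None);
}
```

The division is integer (floor) division, so trailing input positions that do
not fill a complete stride step are dropped.
*/
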
impl<T: Float> Conv2d<T> {
    /// Create a new `Conv2d` layer (dense, dilation `(1, 1)`, `groups = 1`).
    ///
    /// Weight is initialized with Kaiming uniform (ReLU gain).
    /// Bias, if enabled, is initialized U(-bound, bound) with
    /// `bound = 1/sqrt(fan_in)` per `torch/nn/modules/conv.py:198-201`.
    ///
    /// This is a thin shim over [`Conv2d::new_full`] preserved for
    /// backwards compatibility with existing callers (see Phase 5 of #1002).
    pub fn new(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
        bias: bool,
    ) -> FerrotorchResult<Self> {
        Self::new_full(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            (1, 1),
            1,
            bias,
        )
    }

    /// Create a new `Conv2d` layer with the full PyTorch-shaped argument set,
    /// including `dilation` and `groups`.
    ///
    /// `groups` must divide BOTH `in_channels` and `out_channels` (PyTorch
    /// `torch.nn.Conv2d` raises `ValueError` otherwise). `dilation` must be
    /// strictly positive in both dimensions. Weight is initialised with
    /// Kaiming uniform (ReLU gain); bias (if enabled) with U(-bound, bound)
    /// where `bound = 1/sqrt(fan_in)` per `torch/nn/modules/conv.py:198-201`.
    ///
    /// # GPU coverage caveat
    ///
    /// `Conv2d::forward`'s CUDA fast path is only taken when `groups == 1 &&
    /// dilation == (1, 1)`. With grouped or dilated configurations the
    /// dispatch gate explicitly falls through to the CPU implementation;
    /// kernel-side `groups`/`dilation` plumbing in the `ferrotorch-gpu`
    /// backend is a separate workitem.
    #[allow(clippy::too_many_arguments)]
    pub fn new_full(
        in_channels: usize,
        out_channels: usize,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
        dilation: (usize, usize),
        groups: usize,
        bias: bool,
    ) -> FerrotorchResult<Self> {
        if in_channels == 0 || out_channels == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "in_channels and out_channels must be > 0".into(),
            });
        }
        if kernel_size.0 == 0 || kernel_size.1 == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "kernel_size must be > 0 in both dimensions".into(),
            });
        }
        if stride.0 == 0 || stride.1 == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "stride must be > 0 in both dimensions".into(),
            });
        }
        if dilation.0 == 0 || dilation.1 == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv2d::new_full: dilation must be > 0 in both dimensions, got {dilation:?}"
                ),
            });
        }
        if groups == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "Conv2d::new_full: groups must be > 0".into(),
            });
        }
        if in_channels % groups != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv2d::new_full: groups={groups} must divide in_channels={in_channels}"
                ),
            });
        }
        if out_channels % groups != 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "Conv2d::new_full: groups={groups} must divide out_channels={out_channels}"
                ),
            });
        }

        let (kh, kw) = kernel_size;
        // PyTorch weight layout is [C_out, C_in / groups, kH, kW].
        let mut weight = Parameter::zeros(&[out_channels, in_channels / groups, kh, kw])?;
        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;

        let bias_param = if bias {
            let mut b = Parameter::zeros(&[out_channels])?;
            // `torch/nn/modules/conv.py:198-201`: `fan_in, _ = init._calculate_fan_in_and_fan_out(weight);
            //   bound = 1 / sqrt(fan_in); init.uniform_(self.bias, -bound, bound)`. For Conv2d
            //   `fan_in = (in_channels/groups) * kH * kW`.
            let fan_in = (in_channels / groups) * kh * kw;
            let bound = if fan_in > 0 {
                1.0 / (fan_in as f64).sqrt()
            } else {
                0.0
            };
            uniform_init(&mut b, -bound, bound)?;
            Some(b)
        } else {
            None
        };

        Ok(Self {
            weight,
            bias: bias_param,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            padding_mode: crate::padding::PaddingMode::Zeros,
            string_padding: None,
            training: true,
        })
    }

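/*
The bias bound used in `new_full` depends on `groups` through `fan_in`. The
sketch below restates that arithmetic with a hypothetical free function
`bias_bound` that returns a `String` error instead of `FerrotorchError`. Both
the name and the error type are assumptions made so the example stands alone.

```rust
// Hypothetical stand-in for the bias-bound computation in the constructor:
// fan_in = (in_channels / groups) * kH * kW, bound = 1 / sqrt(fan_in).
fn bias_bound(
    in_channels: usize,
    groups: usize,
    kernel_size: (usize, usize),
) -> Result<f64, String> {
    // `groups == 0` is tested first, so the modulo never divides by zero.
    if groups == 0 || in_channels % groups != 0 {
        return Err(format!(
            "groups={groups} must be > 0 and divide in_channels={in_channels}"
        ));
    }
    let fan_in = (in_channels / groups) * kernel_size.0 * kernel_size.1;
    Ok(if fan_in > 0 {
        1.0 / (fan_in as f64).sqrt()
    } else {
        0.0
    })
}

fn demo_bias_bound() {
    // Dense 3x3 conv over 64 channels: fan_in = 576, bound = 1/24.
    assert!((bias_bound(64, 1, (3, 3)).unwrap() - 1.0 / 24.0).abs() < 1e-12);
    // Depthwise 3x3 conv: each filter sees one channel, fan_in = 9, bound = 1/3.
    assert!((bias_bound(32, 32, (3, 3)).unwrap() - 1.0 / 3.0).abs() < 1e-12);
    // groups must divide in_channels.
    assert!(bias_bound(10, 3, (3, 3)).is_err());
}
```

Grouping shrinks `fan_in`, so a depthwise layer draws its bias from a wider
interval than a dense layer with the same kernel size.
*/
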
    /// Configure the boundary handling for the spatial padding.
    ///
    /// `Zeros` (default) uses the existing im2col zero-pad path.
    /// `Reflect`, `Replicate`, and `Circular` pre-pad the input via
    /// `crate::padding::functional_pad_2d(input, ...)` and then convolve
    /// with `padding = 0`, matching `torch.nn.Conv2d(..., padding_mode=...)`
    /// (`_ConvNd._conv_forward`'s `F.pad` shape). Closes #1443.
    pub fn with_padding_mode(mut self, mode: crate::padding::PaddingMode) -> Self {
        self.padding_mode = mode;
        self
    }

    /// Configure string padding (`'same'` / `'valid'`), mirroring the
    /// `padding: str` branch of `torch.nn.Conv2d` (`conv.py:111-155`).
    ///
    /// `StringPadding::Valid` is equivalent to `padding = 0`.
    /// `StringPadding::Same` pads so the output spatial size equals the input
    /// spatial size (for `stride = 1`), splitting the per-dim total
    /// `dilation * (kernel - 1)` asymmetrically as `left = total/2`,
    /// `right = total - left` (the END gets the extra unit; see
    /// [`same_pad_lr`]). The pre-pad uses the configured `padding_mode`
    /// (constant-0 for the default `Zeros`, matching
    /// `convolution_same`'s `constant_pad_nd(..., 0)` at
    /// `Convolution.cpp:1105`) and is autograd-aware via `Pad2dBackward`.
    ///
    /// Returns `Err` if `StringPadding::Same` is requested with a stride other
    /// than 1 in any dimension, matching upstream
    /// `raise ValueError("padding='same' is not supported for strided
    /// convolutions")` (`conv.py:117-120`, `Convolution.cpp:1071`). Closes
    /// #1602.
    pub fn with_string_padding(mut self, padding: StringPadding) -> FerrotorchResult<Self> {
        if padding == StringPadding::Same && (self.stride.0 != 1 || self.stride.1 != 1) {
            return Err(FerrotorchError::InvalidArgument {
                message: "padding='same' is not supported for strided convolutions".into(),
            });
        }
        self.string_padding = Some(padding);
        self.padding = (0, 0);
        Ok(self)
    }

    /// Replace the kernel weights with a caller-supplied [`Parameter`].
    ///
    /// The new weight must have shape `[out_channels, in_channels / groups,
    /// kH, kW]` (i.e. the same shape as the existing parameter). Used by
    /// tests and tooling that need deterministic weights without going
    /// through [`Conv2d::from_parts`].
    pub fn set_weight(&mut self, weight: Parameter<T>) -> FerrotorchResult<()> {
        let expected = [
            self.out_channels,
            self.in_channels / self.groups,
            self.kernel_size.0,
            self.kernel_size.1,
        ];
        let got = weight.tensor().shape();
        if got != expected {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!("Conv2d::set_weight: expected shape {expected:?}, got {got:?}"),
            });
        }
        self.weight = weight;
        Ok(())
    }

    /// Number of channel groups (`1` is dense, `in_channels` is depthwise).
    pub fn groups(&self) -> usize {
        self.groups
    }

    /// Dilation `(dilH, dilW)` (`(1, 1)` is the dense default).
    pub fn dilation(&self) -> (usize, usize) {
        self.dilation
    }

    /// The number of learnable scalar parameters.
    ///
    /// For grouped convolutions the weight tensor has shape
    /// `[out_channels, in_channels / groups, kH, kW]`, so the weight count is
    /// `1 / groups` of the dense count; the bias count (`out_channels`, when
    /// present) does not depend on `groups`.
    pub fn num_parameters(&self) -> usize {
        let w = self.out_channels
            * (self.in_channels / self.groups)
            * self.kernel_size.0
            * self.kernel_size.1;
        let b = if self.bias.is_some() {
            self.out_channels
        } else {
            0
        };
        w + b
    }

630    /// Build a `Conv2d` from caller-supplied weight and optional bias tensors.
631    ///
632    /// `weight` must have shape `[out_channels, in_channels, kH, kW]`. If
633    /// `bias` is provided, it must be 1-D of length `out_channels`. The
634    /// stride and padding are passed through unchanged; the resulting layer
635    /// is dense (`groups = 1`, `dilation = (1, 1)`) so this constructor is
636    /// API-compatible with the pre-Phase-5 surface. This is the constructor
637    /// used by `nn::functional::conv2d` so callers can drive the existing
638    /// im2col + matmul forward path with their own parameters (e.g. for
639    /// stateless functional dispatch or weight sharing across modules).
640    pub fn from_parts(
641        weight: Tensor<T>,
642        bias: Option<Tensor<T>>,
643        stride: (usize, usize),
644        padding: (usize, usize),
645    ) -> FerrotorchResult<Self> {
646        if weight.ndim() != 4 {
647            return Err(FerrotorchError::ShapeMismatch {
648                message: format!(
649                    "Conv2d::from_parts: weight must be 4-D [out, in, kH, kW], got {:?}",
650                    weight.shape()
651                ),
652            });
653        }
654        let out_channels = weight.shape()[0];
655        let in_channels = weight.shape()[1];
656        let kernel_size = (weight.shape()[2], weight.shape()[3]);
657        if let Some(b) = &bias {
658            if b.ndim() != 1 || b.shape()[0] != out_channels {
659                return Err(FerrotorchError::ShapeMismatch {
660                    message: format!(
661                        "Conv2d::from_parts: bias shape {:?} != [{}]",
662                        b.shape(),
663                        out_channels
664                    ),
665                });
666            }
667        }
668        Ok(Self {
669            weight: Parameter::new(weight),
670            bias: bias.map(Parameter::new),
671            in_channels,
672            out_channels,
673            kernel_size,
674            stride,
675            padding,
676            dilation: (1, 1),
677            groups: 1,
678            padding_mode: crate::padding::PaddingMode::Zeros,
679            string_padding: None,
680            training: true,
681        })
682    }
683}
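
// A minimal sketch of the parameter-count arithmetic documented on
// `Conv2d::num_parameters` above. The free function `conv2d_param_count`
// and its argument names are hypothetical (they are not part of this
// crate), and the sketch assumes `in_channels` is divisible by `groups`.

```rust
/// Hypothetical helper: number of learnable scalars in a 2-D conv layer.
fn conv2d_param_count(
    out_channels: usize,
    in_channels: usize,
    groups: usize,
    kernel_size: (usize, usize),
    has_bias: bool,
) -> usize {
    // Weight shape is [out_channels, in_channels / groups, kH, kW].
    let weight = out_channels * (in_channels / groups) * kernel_size.0 * kernel_size.1;
    // Bias, when present, is one scalar per output channel regardless of groups.
    let bias = if has_bias { out_channels } else { 0 };
    weight + bias
}
```

// For example, a dense 32 -> 64 channel 3x3 layer with bias has
// 64 * 32 * 9 + 64 scalars, while a depthwise 32-channel 3x3 layer
// without bias has only 32 * 1 * 9.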
684
685impl<T: Float> Conv2d<T> {
686    /// Build a shallow clone of this layer with the geometry fields
687    /// overridden (used by `forward` to recurse onto the dense
688    /// zero-padding im2col path after a string-padding / non-zero
689    /// `padding_mode` pre-pad). The weight/bias `Parameter`s are re-wrapped
690    /// (cheap `Arc` clone of the underlying tensor storage); `string_padding`
691    /// is cleared so the recursion runs the numeric-padding path.
692    fn recurse_clone(
693        &self,
694        padding: (usize, usize),
695        padding_mode: crate::padding::PaddingMode,
696    ) -> Conv2d<T> {
697        Conv2d {
698            weight: Parameter::new(self.weight.tensor().clone()),
699            bias: self
700                .bias
701                .as_ref()
702                .map(|b| Parameter::new(b.tensor().clone())),
703            in_channels: self.in_channels,
704            out_channels: self.out_channels,
705            kernel_size: self.kernel_size,
706            stride: self.stride,
707            padding,
708            dilation: self.dilation,
709            groups: self.groups,
710            padding_mode,
711            string_padding: None,
712            training: self.training,
713        }
714    }
715}
716
717impl<T: Float> Module<T> for Conv2d<T> {
718    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
719        // Record autocast decision for conv2d.
720        let _autocast_cat = autocast_guard("conv2d");
721
722        // Unbatched input: `(C, H, W)` (rank 3) is accepted in addition to the
723        // batched `(N, C, H, W)` (rank 4) form. Mirrors `batchify` /
724        // `_conv_forward` at `aten/src/ATen/native/Convolution.cpp:816-831,
725        // 1015-1022`: an unbatched input is `unsqueeze(0)`d to add a batch
726        // dim, convolved, then `squeeze(0)`d so the output is rank 3. The
727        // unsqueeze/squeeze are autograd-aware (`UnsqueezeBackward` /
728        // `SqueezeBackward`) so gradients flow back to the unbatched shape.
729        // Closes #1604.
730        if input.ndim() == 3 {
731            let batched = unsqueeze(input, 0)?;
732            let output = self.forward(&batched)?;
733            return squeeze(&output, 0);
734        }
735
736        // String padding ('same' / 'valid'), mirroring the `padding: str`
737        // branch of `torch.nn.Conv2d` (`conv.py:111-155`,
738        // `Convolution.cpp:1119-1124`). `Valid` == numeric `padding = 0`;
739        // `Same` pre-pads asymmetrically (`left = total/2`, `right = total -
740        // left`, the END side getting the extra unit) via the autograd-aware
741        // `functional_pad_2d` then convolves with `padding = 0` — exactly the
742        // `convolution_same` -> `constant_pad_nd(.., 0)` path at
743        // `Convolution.cpp:1098-1108`. The pre-pad fill follows the configured
744        // `padding_mode` (constant-0 for the default `Zeros`). The stride>1
745        // rejection already happened at `with_string_padding` construction
746        // (`conv.py:117-120`). Closes #1602.
747        if let Some(sp) = self.string_padding {
748            match sp {
749                StringPadding::Valid => {
750                    // 'valid' == no padding.
751                    return self
752                        .recurse_clone((0, 0), crate::padding::PaddingMode::Zeros)
753                        .forward(input);
754                }
755                StringPadding::Same => {
756                    let (kh, kw) = self.kernel_size;
757                    let (dh, dw) = self.dilation;
758                    let (top, bottom) = same_pad_lr(kh, dh);
759                    let (left, right) = same_pad_lr(kw, dw);
760                    let padded = crate::padding::functional_pad_2d(
761                        input,
762                        left,
763                        right,
764                        top,
765                        bottom,
766                        self.padding_mode,
767                        <T as num_traits::Zero>::zero(),
768                    )?;
769                    return self
770                        .recurse_clone((0, 0), crate::padding::PaddingMode::Zeros)
771                        .forward(&padded);
772                }
773            }
774        }
775
776        // Non-zero padding modes: pre-pad the input with the requested
777        // boundary mode and then convolve with padding = 0. Mirrors
778        // `torch/nn/modules/conv.py` `_ConvNd._conv_forward`:
779        //   if self.padding_mode != 'zeros':
780        //       input = F.pad(input,
781        //                     self._reversed_padding_repeated_twice,
782        //                     mode=self.padding_mode)
783        //       conv2d(..., padding=0, ...)
784        // Closes #1443.
785        if self.padding_mode != crate::padding::PaddingMode::Zeros
786            && (self.padding.0 != 0 || self.padding.1 != 0)
787        {
788            let padded = crate::padding::functional_pad_2d(
789                input,
790                self.padding.1,
791                self.padding.1,
792                self.padding.0,
793                self.padding.0,
794                self.padding_mode,
795                <T as num_traits::Zero>::zero(),
796            )?;
797            // Recurse on a zero-padding variant. Build a shallow clone with
798            // padding = (0, 0) and padding_mode = Zeros so the existing
799            // im2col-on-zero-pad path runs without re-padding.
800            return self
801                .recurse_clone((0, 0), crate::padding::PaddingMode::Zeros)
802                .forward(&padded);
803        }
804
805        // Validate input shape: [B, C_in, H, W].
806        if input.ndim() != 4 {
807            return Err(FerrotorchError::InvalidArgument {
808                message: format!(
809                    "Conv2d expects 4-D input [B, C, H, W], got {:?}",
810                    input.shape()
811                ),
812            });
813        }
814
815        let batch = input.shape()[0];
816        let c_in = input.shape()[1];
817        let h = input.shape()[2];
818        let w = input.shape()[3];
819
820        if c_in != self.in_channels {
821            return Err(FerrotorchError::ShapeMismatch {
822                message: format!(
823                    "Conv2d: expected {} input channels, got {}",
824                    self.in_channels, c_in
825                ),
826            });
827        }
828
829        let (kh, kw) = self.kernel_size;
830        let (sh, sw) = self.stride;
831        let (ph, pw) = self.padding;
832        let (dh, dw) = self.dilation;
833        let groups = self.groups;
834
835        // Effective kernel extent after dilation.
836        let eff_kh = dh * (kh - 1) + 1;
837        let eff_kw = dw * (kw - 1) + 1;
838
839        // Check that the (effective) kernel fits.
840        let h_padded = h + 2 * ph;
841        let w_padded = w + 2 * pw;
842        if h_padded < eff_kh || w_padded < eff_kw {
843            return Err(FerrotorchError::InvalidArgument {
844                message: format!(
845                    "Conv2d: padded input ({h_padded}, {w_padded}) is smaller than effective kernel ({eff_kh}, {eff_kw})"
846                ),
847            });
848        }
849
850        let h_out = (h_padded - eff_kh) / sh + 1;
851        let w_out = (w_padded - eff_kw) / sw + 1;
852
853        // Save the input device so we can restore it on the output.
854        let input_device = input.device();
855
856        // ---- GPU fast path: fully on-device conv2d ----
857        //
858        // Pass 2A (#1003): the CUDA backend's `conv2d_f32` accepts groups
859        // and dilation natively. Every f32 CUDA input dispatches to the
860        // GPU regardless of `groups` / `dilation`; the kernel does the
861        // per-group im2col + GEMM on-device. The pre-Pass-2A
862        // `gpu_eligible = groups == 1 && dilation == (1, 1)` gate is
863        // gone — keeping it would be a stub-shaped CPU detour that
864        // failure mode #15 explicitly forbids.
865        //
866        // Note: the weight layout passed to the backend is
867        // `[C_out, C_in / groups, kH, kW]` — the PyTorch grouped-conv
868        // convention. `Conv2d::new_full` already constructs `self.weight`
869        // in that shape (see `Conv2d::new_full` for the `in_per_group =
870        // in_channels / groups` slice).
871        let is_f32 = std::mem::size_of::<T>() == 4;
872        if is_f32 && input.is_cuda() {
873            if let Some(backend) = ferrotorch_core::gpu_dispatch::gpu_backend() {
874                let bias_handle = self
875                    .bias
876                    .as_ref()
877                    .and_then(|b| b.tensor().gpu_handle().ok());
878                let weight_shape = self.weight.tensor().shape();
879                let weight_dim4: [usize; 4] = [
880                    weight_shape[0],
881                    weight_shape[1],
882                    weight_shape[2],
883                    weight_shape[3],
884                ];
885                let (out_handle, out_shape) = backend.conv2d_f32(
886                    input.gpu_handle()?,
887                    self.weight.tensor().gpu_handle()?,
888                    bias_handle,
889                    [batch, c_in, h, w],
890                    weight_dim4,
891                    self.stride,
892                    self.padding,
893                    self.dilation,
894                    groups,
895                )?;
896
897                let result = Tensor::from_storage(
898                    TensorStorage::gpu(out_handle),
899                    out_shape.to_vec(),
900                    false,
901                )?;
902
903                // GPU backward is not implemented yet: when gradients are needed,
904                // attach the CPU `Conv2dBackward` node to the GPU-computed output.
905                if is_grad_enabled()
906                    && (input.requires_grad()
907                        || self.weight.requires_grad()
908                        || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
909                {
910                    // Host-side im2col columns for the CPU backward; dilation-aware so `col_cols == h_out * w_out`.
911                    let input_data = input.data_vec()?;
912                    let (cols, col_rows, col_cols) =
913                        im2col_dilated(&input_data, batch, c_in, h, w, kh, kw, sh, sw, ph, pw, dh, dw);
914                    let grad_fn = Arc::new(Conv2dBackward {
915                        input: input.clone(),
916                        weight: self.weight.tensor().clone(),
917                        bias: self.bias.as_ref().map(|b| b.tensor().clone()),
918                        in_channels: self.in_channels,
919                        out_channels: self.out_channels,
920                        kernel_size: self.kernel_size,
921                        stride: self.stride,
922                        padding: self.padding,
923                        dilation: self.dilation,
924                        groups: self.groups,
925                        cols,
926                        col_rows,
927                        col_cols,
928                        h_out,
929                        w_out,
930                    });
931                    return Tensor::from_operation(
932                        result.into_storage_and_shape()?.0,
933                        out_shape.to_vec(),
934                        grad_fn,
935                    );
936                }
937
938                return Ok(result);
939            }
940        }
941
942        // ---- CPU path (handles dense, dilated, grouped, and grouped+dilated) ----
943        let input_data = input.data_vec()?;
944        let weight_data = self.weight.data_vec()?;
945
946        let zero = <T as num_traits::Zero>::zero();
947        let mut output = vec![zero; batch * self.out_channels * h_out * w_out];
948
949        // Per-group dimensions.
950        let in_per_group = self.in_channels / groups;
951        let out_per_group = self.out_channels / groups;
952        let weight_per_group_numel = out_per_group * in_per_group * kh * kw;
953        let group_col_rows = in_per_group * kh * kw;
954        let col_cols = h_out * w_out;
955
956        // Saved im2col columns for autograd (full, ungrouped layout — channel
957        // axis kept dense at C_in so the backward can accumulate grad_input
958        // back into a `[B, C_in, H, W]` tensor uniformly across groups).
959        let saved_cols_rows = self.in_channels * kh * kw;
960        let mut saved_cols: Vec<T> = if is_grad_enabled()
961            && (input.requires_grad()
962                || self.weight.requires_grad()
963                || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
964        {
965            vec![zero; batch * saved_cols_rows * col_cols]
966        } else {
967            Vec::new()
968        };
969        let save_cols = !saved_cols.is_empty();
970
971        for g in 0..groups {
972            // Slice the input channels belonging to this group: [B, in_per_group, H, W].
973            let mut group_input = vec![zero; batch * in_per_group * h * w];
974            for b in 0..batch {
975                for c in 0..in_per_group {
976                    let src_c = g * in_per_group + c;
977                    let src_start = b * self.in_channels * h * w + src_c * h * w;
978                    let dst_start = b * in_per_group * h * w + c * h * w;
979                    group_input[dst_start..dst_start + h * w]
980                        .copy_from_slice(&input_data[src_start..src_start + h * w]);
981                }
982            }
983
984            let (g_cols, g_col_rows, g_col_cols) = im2col_dilated(
985                &group_input,
986                batch,
987                in_per_group,
988                h,
989                w,
990                kh,
991                kw,
992                sh,
993                sw,
994                ph,
995                pw,
996                dh,
997                dw,
998            );
999            debug_assert_eq!(g_col_rows, group_col_rows);
1000            debug_assert_eq!(g_col_cols, col_cols);
1001
1002            // Save into the dense [C_in*kH*kW, col_cols] layout if backward will need it.
1003            if save_cols {
1004                for b in 0..batch {
1005                    for c in 0..in_per_group {
1006                        let dst_c = g * in_per_group + c;
1007                        for kk in 0..(kh * kw) {
1008                            let src_row = c * kh * kw + kk;
1009                            let dst_row = dst_c * kh * kw + kk;
1010                            let src_off = b * group_col_rows * col_cols + src_row * col_cols;
1011                            let dst_off = b * saved_cols_rows * col_cols + dst_row * col_cols;
1012                            saved_cols[dst_off..dst_off + col_cols]
1013                                .copy_from_slice(&g_cols[src_off..src_off + col_cols]);
1014                        }
1015                    }
1016                }
1017            }
1018
1019            // Group's slice of the weight: [out_per_group, in_per_group, kh, kw]
1020            // flattened to [out_per_group, group_col_rows].
1021            let w_group_start = g * weight_per_group_numel;
1022            let w_group_end = w_group_start + weight_per_group_numel;
1023            let weight_group_2d = Tensor::from_storage(
1024                TensorStorage::cpu(weight_data[w_group_start..w_group_end].to_vec()),
1025                vec![out_per_group, group_col_rows],
1026                false,
1027            )?;
1028
1029            for b in 0..batch {
1030                let col_start = b * group_col_rows * col_cols;
1031                let col_end = col_start + group_col_rows * col_cols;
1032                let cols_b = Tensor::from_storage(
1033                    TensorStorage::cpu(g_cols[col_start..col_end].to_vec()),
1034                    vec![group_col_rows, col_cols],
1035                    false,
1036                )?;
1037
1038                let out_b = mm(&weight_group_2d, &cols_b)?;
1039                let out_data = out_b.data()?;
1040                // Place this group's output channels into [b, g*out_per_group..(g+1)*out_per_group, :, :].
1041                for oc in 0..out_per_group {
1042                    let dst_c = g * out_per_group + oc;
1043                    let dst_start = b * self.out_channels * h_out * w_out + dst_c * h_out * w_out;
1044                    let src_start = oc * h_out * w_out;
1045                    output[dst_start..dst_start + h_out * w_out]
1046                        .copy_from_slice(&out_data[src_start..src_start + h_out * w_out]);
1047                }
1048            }
1049        }
1050
1051        // Add bias if present: broadcast [C_out] over [B, C_out, H_out, W_out].
1052        if let Some(ref bias) = self.bias {
1053            let bias_data = bias.data_vec()?;
1054            for b in 0..batch {
1055                for c in 0..self.out_channels {
1056                    let bval = bias_data[c];
1057                    for hw in 0..(h_out * w_out) {
1058                        output[b * self.out_channels * h_out * w_out + c * h_out * w_out + hw] +=
1059                            bval;
1060                    }
1061                }
1062            }
1063        }
1064
1065        let result = Tensor::from_storage(
1066            TensorStorage::cpu(output),
1067            vec![batch, self.out_channels, h_out, w_out],
1068            false,
1069        )?;
1070
1071        // Attach backward if gradients are enabled and any input/param requires grad.
1072        if save_cols {
1073            let grad_fn = Arc::new(Conv2dBackward {
1074                input: input.clone(),
1075                weight: self.weight.tensor().clone(),
1076                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
1077                in_channels: self.in_channels,
1078                out_channels: self.out_channels,
1079                kernel_size: self.kernel_size,
1080                stride: self.stride,
1081                padding: self.padding,
1082                dilation: self.dilation,
1083                groups: self.groups,
1084                cols: saved_cols,
1085                col_rows: saved_cols_rows,
1086                col_cols,
1087                h_out,
1088                w_out,
1089            });
1090            Tensor::from_operation(
1091                TensorStorage::cpu(result.data()?.to_vec()),
1092                result.shape().to_vec(),
1093                grad_fn,
1094            )?
1095            .to(input_device) // restore device
1096        } else {
1097            result.to(input_device)
1098        }
1099    }
1100
1101    fn parameters(&self) -> Vec<&Parameter<T>> {
1102        let mut params = vec![&self.weight];
1103        if let Some(ref b) = self.bias {
1104            params.push(b);
1105        }
1106        params
1107    }
1108
1109    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1110        let mut params = vec![&mut self.weight];
1111        if let Some(ref mut b) = self.bias {
1112            params.push(b);
1113        }
1114        params
1115    }
1116
1117    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1118        let mut params = vec![("weight".to_string(), &self.weight)];
1119        if let Some(ref b) = self.bias {
1120            params.push(("bias".to_string(), b));
1121        }
1122        params
1123    }
1124
1125    fn train(&mut self) {
1126        self.training = true;
1127    }
1128
1129    fn eval(&mut self) {
1130        self.training = false;
1131    }
1132
1133    fn is_training(&self) -> bool {
1134        self.training
1135    }
1136}
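
// A rough sketch of the shape arithmetic that `forward` relies on: the
// effective kernel extent under dilation, the output length along one
// axis, and the asymmetric 'same' split (`left = total/2`,
// `right = total - left`). Both helper names are hypothetical, and
// `total = dilation * (kernel - 1)` is an assumption taken from the
// PyTorch convention; the body of this crate's `same_pad_lr` is not
// shown here.

```rust
/// Hypothetical helper: output length along one spatial axis, or `None`
/// when the padded input is smaller than the effective kernel.
fn conv_out_len(
    len: usize,
    kernel: usize,
    stride: usize,
    pad: usize,
    dilation: usize,
) -> Option<usize> {
    // Effective kernel extent after dilation: d * (k - 1) + 1.
    let eff_kernel = dilation * kernel.checked_sub(1)? + 1;
    let padded = len + 2 * pad;
    if stride == 0 || padded < eff_kernel {
        return None;
    }
    Some((padded - eff_kernel) / stride + 1)
}

/// Hypothetical helper: 'same' padding split for a stride-1 axis.
/// The end side receives the extra unit when the total is odd.
fn same_pad_split(kernel: usize, dilation: usize) -> (usize, usize) {
    let total = dilation * kernel.saturating_sub(1);
    let begin = total / 2;
    (begin, total - begin)
}
```

// With stride 1, pre-padding by `begin + end` and convolving with
// `padding = 0` leaves the spatial length unchanged, which is why the
// 'same' branch can recurse onto the zero-padding path.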
1137
1138// ---------------------------------------------------------------------------
1139// Conv2dBackward
1140// ---------------------------------------------------------------------------
1141
1142/// Backward function for the `Conv2d` forward pass.
1143///
1144/// Saved tensors:
1145/// - `input`: the original 4-D input
1146/// - `weight`: the 4-D kernel `[C_out, C_in / groups, kH, kW]`
1147/// - `bias`: optional 1-D bias
1148/// - `cols`: the im2col columns from the forward pass with **dense channel
1149///   layout** `[B, C_in * kH * kW, H_out * W_out]`. The forward saves into
1150///   this shape regardless of `groups` so the backward can reuse a uniform
1151///   indexing scheme; for `groups > 1` the per-group slice is taken at
1152///   gradient-computation time.
1153/// - `dilation`, `groups`: extra descriptors needed to reconstruct the
1154///   per-group + dilated math without re-reading them off the layer.
1155#[derive(Debug)]
1156struct Conv2dBackward<T: Float> {
1157    input: Tensor<T>,
1158    weight: Tensor<T>,
1159    bias: Option<Tensor<T>>,
1160    in_channels: usize,
1161    out_channels: usize,
1162    kernel_size: (usize, usize),
1163    stride: (usize, usize),
1164    padding: (usize, usize),
1165    dilation: (usize, usize),
1166    groups: usize,
1167    cols: Vec<T>,
1168    col_rows: usize,
1169    col_cols: usize,
1170    h_out: usize,
1171    w_out: usize,
1172}
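
// The `cols` layout described in the `Conv2dBackward` docs above can be
// illustrated with a small single-image sketch. `im2col_sketch` and
// `conv_via_im2col` are hypothetical names written for this example (they
// are not the crate's `im2col_dilated` or `mm`), fixed to `f32`, and they
// assume the effective kernel fits inside the padded input.

```rust
/// Hypothetical single-image im2col with dilation. Returns
/// `(cols, rows, n_cols)` where `cols` is a row-major
/// `[C * kH * kW, H_out * W_out]` matrix.
fn im2col_sketch(
    input: &[f32],
    (c, h, w): (usize, usize, usize),
    (kh, kw): (usize, usize),
    (sh, sw): (usize, usize),
    (ph, pw): (usize, usize),
    (dh, dw): (usize, usize),
) -> (Vec<f32>, usize, usize) {
    let h_out = (h + 2 * ph - (dh * (kh - 1) + 1)) / sh + 1;
    let w_out = (w + 2 * pw - (dw * (kw - 1) + 1)) / sw + 1;
    let (rows, n_cols) = (c * kh * kw, h_out * w_out);
    let mut cols = vec![0.0f32; rows * n_cols];
    for ch in 0..c {
        for ki in 0..kh {
            for kj in 0..kw {
                // Row index: channel-major, then kernel row, then kernel column.
                let row = (ch * kh + ki) * kw + kj;
                for oh in 0..h_out {
                    for ow in 0..w_out {
                        // Coordinates in the unpadded image; positions that
                        // fall in the zero padding keep the initial 0.
                        let ih = (oh * sh + ki * dh) as isize - ph as isize;
                        let iw = (ow * sw + kj * dw) as isize - pw as isize;
                        if ih >= 0 && (ih as usize) < h && iw >= 0 && (iw as usize) < w {
                            cols[row * n_cols + oh * w_out + ow] =
                                input[(ch * h + ih as usize) * w + iw as usize];
                        }
                    }
                }
            }
        }
    }
    (cols, rows, n_cols)
}

/// Dense conv for one image as `weight [C_out, rows] @ cols [rows, n_cols]`.
fn conv_via_im2col(
    weight: &[f32],
    c_out: usize,
    cols: &[f32],
    rows: usize,
    n_cols: usize,
) -> Vec<f32> {
    let mut out = vec![0.0f32; c_out * n_cols];
    for oc in 0..c_out {
        for r in 0..rows {
            let wv = weight[oc * rows + r];
            for j in 0..n_cols {
                out[oc * n_cols + j] += wv * cols[r * n_cols + j];
            }
        }
    }
    out
}
```

// Each column holds one receptive field, so a single matrix product
// yields every output position at once; the backward reuses the same
// columns for `grad_weight`.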
1173
1174impl<T: Float> GradFn<T> for Conv2dBackward<T> {
1175    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1176        // grad_output shape: [B, C_out, H_out, W_out].
1177        //
1178        // The backward computation is host-side (im2col / col2im / mm),
1179        // so the produced tensors land on CPU. They must be lifted back
1180        // onto the saved input/weight devices before being returned;
1181        // otherwise downstream gradient ops (e.g. relu_backward, the
1182        // optimizer step) see CPU tensors mixed with CUDA parameters
1183        // and either fall into the `NotImplementedOnCuda` branch or
1184        // fail device assertions in the optimizer. Surfaced by
1185        // `gpu_cnn_training_smoke` in
1186        // `ferrotorch/tests/gpu_training.rs` (#749 Section B).
1187        //
1188        // Note: this is a transitional fix that keeps the chain
1189        // device-consistent while the actual GPU im2col/col2im backward
1190        // kernels are written. A full Conv2dBackward GPU implementation
1191        // is tracked separately (see Section B report).
1192        let input_device = self.input.device();
1193        let weight_device = self.weight.device();
1194        let bias_device = self.bias.as_ref().map(|b| b.device());
1195        let go_data = grad_output.data_vec()?;
1196        let batch = self.input.shape()[0];
1197        let h = self.input.shape()[2];
1198        let w = self.input.shape()[3];
1199        let (kh, kw) = self.kernel_size;
1200        let (sh, sw) = self.stride;
1201        let (ph, pw) = self.padding;
1202        let (dh, dw) = self.dilation;
1203        let groups = self.groups;
1204        let in_per_group = self.in_channels / groups;
1205        let out_per_group = self.out_channels / groups;
1206        let group_col_rows = in_per_group * kh * kw;
1207        let zero = <T as num_traits::Zero>::zero();
1208
1209        // --- grad_weight ---
1210        // Per group `g`:
1211        //   grad_output_b_g : [out_per_group, H_out * W_out]
1212        //   cols_b_g        : [in_per_group * kH * kW, H_out * W_out]
1213        //   gw_g           += grad_output_b_g @ cols_b_g^T
1214        // Stack groups along the C_out axis to recover [C_out, C_in/G, kH, kW].
1215        let grad_weight = if self.weight.requires_grad() {
1216            let weight_numel = self.out_channels * in_per_group * kh * kw;
1217            let mut gw_accum = vec![zero; weight_numel];
1218            let weight_per_group_numel = out_per_group * group_col_rows;
1219
1220            for g in 0..groups {
1221                for b in 0..batch {
1222                    // Slice grad_output for this group: [out_per_group, h_out * w_out].
1223                    let mut go_g = vec![zero; out_per_group * self.h_out * self.w_out];
1224                    for oc in 0..out_per_group {
1225                        let src_c = g * out_per_group + oc;
1226                        let src_start = b * self.out_channels * self.h_out * self.w_out
1227                            + src_c * self.h_out * self.w_out;
1228                        let dst_start = oc * self.h_out * self.w_out;
1229                        go_g[dst_start..dst_start + self.h_out * self.w_out].copy_from_slice(
1230                            &go_data[src_start..src_start + self.h_out * self.w_out],
1231                        );
1232                    }
1233                    let go_b_g = Tensor::from_storage(
1234                        TensorStorage::cpu(go_g),
1235                        vec![out_per_group, self.h_out * self.w_out],
1236                        false,
1237                    )?;
1238
1239                    // Slice cols for this group: [in_per_group * kH * kW, col_cols].
1240                    let mut cols_g = vec![zero; group_col_rows * self.col_cols];
1241                    for c in 0..in_per_group {
1242                        let src_c = g * in_per_group + c;
1243                        for kk in 0..(kh * kw) {
1244                            let src_row = src_c * kh * kw + kk;
1245                            let dst_row = c * kh * kw + kk;
1246                            let src_off =
1247                                b * self.col_rows * self.col_cols + src_row * self.col_cols;
1248                            let dst_off = dst_row * self.col_cols;
1249                            cols_g[dst_off..dst_off + self.col_cols]
1250                                .copy_from_slice(&self.cols[src_off..src_off + self.col_cols]);
1251                        }
1252                    }
1253                    let cols_b_g = Tensor::from_storage(
1254                        TensorStorage::cpu(cols_g),
1255                        vec![group_col_rows, self.col_cols],
1256                        false,
1257                    )?;
1258
1259                    let cols_bt = transpose(&cols_b_g)?;
1260                    let gw_b = mm(&go_b_g, &cols_bt)?;
1261                    let gw_data = gw_b.data()?;
1262
1263                    let dst_off = g * weight_per_group_numel;
1264                    for i in 0..weight_per_group_numel {
1265                        gw_accum[dst_off + i] += gw_data[i];
1266                    }
1267                }
1268            }
1269
1270            Some(
1271                Tensor::from_storage(
1272                    TensorStorage::cpu(gw_accum),
1273                    vec![self.out_channels, in_per_group, kh, kw],
1274                    false,
1275                )?
1276                .to(weight_device)?,
1277            )
1278        } else {
1279            None
1280        };
1281
1282        // --- grad_bias ---
1283        // Sum grad_output over batch, height, width: sum over [B, *, H_out, W_out]
1284        // Result shape: [C_out]. Bias is per-output-channel, identical for any
1285        // groups setting (shape `[C_out]`), so this is unchanged from the dense path.
1286        let grad_bias = match &self.bias {
1287            Some(b) if b.requires_grad() => {
1288                let mut gb = vec![zero; self.out_channels];
1289                for batch_idx in 0..batch {
1290                    for c in 0..self.out_channels {
1291                        for hw in 0..(self.h_out * self.w_out) {
1292                            gb[c] +=
1293                                go_data[batch_idx * self.out_channels * self.h_out * self.w_out
1294                                    + c * self.h_out * self.w_out
1295                                    + hw];
1296                        }
1297                    }
1298                }
1299                let target_dev = bias_device.unwrap_or(input_device);
1300                Some(
1301                    Tensor::from_storage(TensorStorage::cpu(gb), vec![self.out_channels], false)?
1302                        .to(target_dev)?,
1303                )
1304            }
1305            _ => None,
1306        };
1307
1308        // --- grad_input ---
1309        // Per group `g`:
1310        //   weight_g_2d_T @ grad_output_b_g -> [in_per_group * kH * kW, H_out * W_out]
1311        //   then col2im_dilated -> [in_per_group, H, W] -> place into the right
1312        //   in-channel slice of [B, C_in, H, W].
1313        let grad_input = if self.input.requires_grad() {
1314            let weight_data = self.weight.data_vec()?;
1315            let mut grad_input_data = vec![zero; batch * self.in_channels * h * w];
1316            let weight_per_group_numel = out_per_group * group_col_rows;
1317
1318            for g in 0..groups {
1319                let w_off = g * weight_per_group_numel;
1320                let weight_g_2d = Tensor::from_storage(
1321                    TensorStorage::cpu(weight_data[w_off..w_off + weight_per_group_numel].to_vec()),
1322                    vec![out_per_group, group_col_rows],
1323                    false,
1324                )?;
1325                let weight_g_t = transpose(&weight_g_2d)?;
1326
1327                let mut grad_cols_g = vec![zero; batch * group_col_rows * self.col_cols];
1328                for b in 0..batch {
1329                    // Slice grad_output for this group/batch.
1330                    let mut go_g = vec![zero; out_per_group * self.h_out * self.w_out];
1331                    for oc in 0..out_per_group {
1332                        let src_c = g * out_per_group + oc;
1333                        let src_start = b * self.out_channels * self.h_out * self.w_out
1334                            + src_c * self.h_out * self.w_out;
1335                        let dst_start = oc * self.h_out * self.w_out;
1336                        go_g[dst_start..dst_start + self.h_out * self.w_out].copy_from_slice(
1337                            &go_data[src_start..src_start + self.h_out * self.w_out],
1338                        );
1339                    }
1340                    let go_b_g = Tensor::from_storage(
1341                        TensorStorage::cpu(go_g),
1342                        vec![out_per_group, self.h_out * self.w_out],
1343                        false,
1344                    )?;
1345
1346                    let gc_b = mm(&weight_g_t, &go_b_g)?;
1347                    let gc_data = gc_b.data()?;
1348                    let gc_start = b * group_col_rows * self.col_cols;
1349                    grad_cols_g[gc_start..gc_start + group_col_rows * self.col_cols]
1350                        .copy_from_slice(gc_data);
1351                }
1352
                // `col2im_dilated` scatters this group's columns back to [B, in_per_group, H, W].
1354                let gi_g = col2im_dilated(
1355                    &grad_cols_g,
1356                    batch,
1357                    in_per_group,
1358                    h,
1359                    w,
1360                    kh,
1361                    kw,
1362                    sh,
1363                    sw,
1364                    ph,
1365                    pw,
1366                    dh,
1367                    dw,
1368                    self.h_out,
1369                    self.w_out,
1370                );
1371
1372                // Place into the corresponding slice of the dense [B, C_in, H, W] tensor.
1373                for b in 0..batch {
1374                    for c in 0..in_per_group {
1375                        let dst_c = g * in_per_group + c;
1376                        let dst_start = b * self.in_channels * h * w + dst_c * h * w;
1377                        let src_start = b * in_per_group * h * w + c * h * w;
1378                        grad_input_data[dst_start..dst_start + h * w]
1379                            .copy_from_slice(&gi_g[src_start..src_start + h * w]);
1380                    }
1381                }
1382            }
1383
1384            Some(
1385                Tensor::from_storage(
1386                    TensorStorage::cpu(grad_input_data),
1387                    self.input.shape().to_vec(),
1388                    false,
1389                )?
1390                .to(input_device)?,
1391            )
1392        } else {
1393            None
1394        };
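/*
Illustrative sketch, not part of this crate: a 1-D, single-sample analogue of
the column-to-image scatter used for `grad_input` above. It is not the crate's
`col2im_dilated` (whose body is not shown here); the name, the argument order
and the `[C * k, L_out]` column layout are assumptions for this example.

```rust
/// Scatter-add columns `[C * k, L_out]` back into a `[C, L]` buffer.
/// Overlapping taps accumulate; taps that read from the zero padding are dropped.
fn col2im_1d_sketch(
    cols: &[f32],
    channels: usize,
    length: usize,
    k: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    l_out: usize,
) -> Vec<f32> {
    assert_eq!(cols.len(), channels * k * l_out);
    let mut out = vec![0.0f32; channels * length];
    for c in 0..channels {
        for kk in 0..k {
            let row = c * k + kk;
            for o in 0..l_out {
                // Position in the padded input that this kernel tap read from.
                let pos = o * stride + kk * dilation;
                // Skip taps that landed in the padding on either side.
                if pos < padding || pos - padding >= length {
                    continue;
                }
                out[c * length + (pos - padding)] += cols[row * l_out + o];
            }
        }
    }
    out
}
```

With `k = 2`, `stride = 1` and no padding over a length-4 input, interior
positions are touched by two taps and the two ends by one, which is why the
scatter must accumulate rather than overwrite.
*/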
1395
1396        // Return exactly as many gradients as inputs() returns.
1397        let mut grads = vec![grad_input, grad_weight];
1398        if self.bias.is_some() {
1399            grads.push(grad_bias);
1400        }
1401        Ok(grads)
1402    }
1403
1404    fn inputs(&self) -> Vec<&Tensor<T>> {
1405        let mut v = vec![&self.input, &self.weight];
1406        if let Some(ref b) = self.bias {
1407            v.push(b);
1408        }
1409        v
1410    }
1411
1412    fn name(&self) -> &'static str {
1413        "Conv2dBackward"
1414    }
1415}
1416
1417// ---------------------------------------------------------------------------
1418// Conv1d
1419// ---------------------------------------------------------------------------
1420
1421/// A 1-D convolution layer for sequence data.
1422///
1423/// Applies a temporal convolution over an input `[B, C_in, L]` using
1424/// the im2col + matmul algorithm (delegates to the 2-D helpers with H=1).
1425/// Equivalent to `torch.nn.Conv1d`.
1426///
1427/// # Shape
1428///
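/// - Input: `[B, in_channels, L]`, or unbatched `[in_channels, L]`
/// - Output: `[B, out_channels, L_out]`, or `[out_channels, L_out]` for unbatched input
///
/// where `L_out = (L + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1`
/// (integer division). With `dilation = 1` this is
/// `(L + 2 * padding - kernel_size) / stride + 1`.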
1433#[derive(Debug)]
1434pub struct Conv1d<T: Float> {
1435    /// Learnable kernel weights `[out_channels, in_channels / groups, kernel_size]`.
1436    weight: Parameter<T>,
1437    /// Optional learnable bias `[out_channels]`.
1438    bias: Option<Parameter<T>>,
1439    /// Number of input channels.
1440    in_channels: usize,
1441    /// Number of output channels (filters).
1442    out_channels: usize,
1443    /// Kernel length.
1444    kernel_size: usize,
1445    /// Stride.
1446    stride: usize,
1447    /// Zero-padding applied to both sides.
1448    padding: usize,
1449    /// Dilation. `1` is the dense default. Spaces kernel taps `dilation`
1450    /// apart along the temporal axis (`eff_kernel = dilation * (k - 1) + 1`),
1451    /// mirroring `torch.nn.Conv1d(..., dilation=1)` (`conv.py:337`).
1452    dilation: usize,
1453    /// Number of blocked input/output channel groups. `1` is dense,
1454    /// `in_channels` is depthwise. Must divide both `in_channels` and
1455    /// `out_channels`, mirroring `torch.nn.Conv1d(..., groups=1)`
1456    /// (`conv.py:338`, validation `conv.py:107-110`).
1457    groups: usize,
1458    /// Boundary handling for the spatial padding. `Zeros` (default) routes
1459    /// through the existing im2col zero-pad path; non-`Zeros` modes pre-pad
1460    /// the input via `crate::padding::functional_pad_1d` and then run the
1461    /// dense im2col over the already-padded tensor (matching the upstream
1462    /// `_ConvNd._conv_forward` for Conv1d: `F.pad(input, ..., mode=...)` first,
1463    /// then a `padding=0` convolution). See `torch/nn/modules/conv.py:367-378`.
1464    /// Closes #1443.
1465    padding_mode: crate::padding::PaddingMode,
1466    /// String padding mode (`'same'` / `'valid'`), `None` when numeric
1467    /// `padding` is used. When `Some`, the `padding` field is ignored and the
1468    /// effective padding is derived per [`StringPadding`] in `forward`
1469    /// (mirroring the `padding: str` branch of `torch.nn.Conv1d`,
1470    /// `torch/nn/modules/conv.py:111-155`). Set via
1471    /// [`Conv1d::with_string_padding`]. Closes #1602.
1472    string_padding: Option<StringPadding>,
1473    /// Whether the module is in training mode.
1474    training: bool,
1475}
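/*
Illustrative sketch, not part of this crate: the output length that `forward`
derives from the effective kernel extent `dilation * (kernel_size - 1) + 1`.
The helper name and the `Option` return are assumptions for this example; the
layer itself reports these cases through `FerrotorchError`.

```rust
/// Output length of a 1-D convolution, or `None` when the arguments are
/// degenerate or the padded input is shorter than the dilated kernel.
fn conv1d_out_len_sketch(
    length: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> Option<usize> {
    if kernel_size == 0 || stride == 0 || dilation == 0 {
        return None;
    }
    // Span covered by the kernel once its taps are spread `dilation` apart.
    let eff_k = dilation * (kernel_size - 1) + 1;
    let padded = length + 2 * padding;
    if padded < eff_k {
        return None;
    }
    Some((padded - eff_k) / stride + 1)
}
```

For example, `conv1d_out_len_sketch(10, 3, 1, 0, 2)` is `Some(6)`: the dilated
kernel spans 5 samples, so 6 window positions fit in a length-10 input.
*/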
1476
1477impl<T: Float> Conv1d<T> {
1478    /// Create a new `Conv1d` layer (dense, dilation `1`, `groups = 1`).
1479    ///
1480    /// Weight is initialized with Kaiming uniform (ReLU gain).
1481    /// Bias, if enabled, is initialized U(-bound, bound) with
1482    /// `bound = 1/sqrt(fan_in)` per `torch/nn/modules/conv.py:198-201`.
1483    ///
1484    /// This is a thin shim over [`Conv1d::new_full`] preserved for callers
1485    /// that only need the dense configuration (e.g. `LazyConv1d::materialize`).
1486    pub fn new(
1487        in_channels: usize,
1488        out_channels: usize,
1489        kernel_size: usize,
1490        stride: usize,
1491        padding: usize,
1492        bias: bool,
1493    ) -> FerrotorchResult<Self> {
1494        Self::new_full(
1495            in_channels,
1496            out_channels,
1497            kernel_size,
1498            stride,
1499            padding,
1500            1,
1501            1,
1502            bias,
1503        )
1504    }
1505
1506    /// Create a new `Conv1d` layer with the full PyTorch-shaped argument set,
1507    /// including `dilation` and `groups`.
1508    ///
1509    /// `groups` must divide BOTH `in_channels` and `out_channels` (PyTorch
1510    /// `torch.nn.Conv1d` raises `ValueError` otherwise, `conv.py:107-110`).
1511    /// `dilation` must be strictly positive. Weight is initialised with
1512    /// Kaiming uniform (ReLU gain); bias (if enabled) with U(-bound, bound)
1513    /// where `bound = 1/sqrt(fan_in)`, `fan_in = (in_channels/groups) *
1514    /// kernel_size` per `torch/nn/modules/conv.py:198-201`.
1515    ///
1516    /// Weight layout is `[out_channels, in_channels / groups, kernel_size]`,
1517    /// the PyTorch grouped-conv convention (`conv.py:171`). Argument order
1518    /// `(.., dilation, groups, bias)` mirrors `Conv1d.__init__`
1519    /// (`conv.py:330-339`, R-DEV-2).
1520    #[allow(clippy::too_many_arguments)]
1521    pub fn new_full(
1522        in_channels: usize,
1523        out_channels: usize,
1524        kernel_size: usize,
1525        stride: usize,
1526        padding: usize,
1527        dilation: usize,
1528        groups: usize,
1529        bias: bool,
1530    ) -> FerrotorchResult<Self> {
1531        if in_channels == 0 || out_channels == 0 {
1532            return Err(FerrotorchError::InvalidArgument {
1533                message: "in_channels and out_channels must be > 0".into(),
1534            });
1535        }
1536        if kernel_size == 0 {
1537            return Err(FerrotorchError::InvalidArgument {
1538                message: "kernel_size must be > 0".into(),
1539            });
1540        }
1541        if stride == 0 {
1542            return Err(FerrotorchError::InvalidArgument {
1543                message: "stride must be > 0".into(),
1544            });
1545        }
1546        if dilation == 0 {
1547            return Err(FerrotorchError::InvalidArgument {
1548                message: format!("Conv1d::new_full: dilation must be > 0, got {dilation}"),
1549            });
1550        }
1551        if groups == 0 {
1552            return Err(FerrotorchError::InvalidArgument {
1553                message: "Conv1d::new_full: groups must be > 0".into(),
1554            });
1555        }
1556        // `torch/nn/modules/conv.py:107-110`: `in_channels % groups != 0`
1557        // and `out_channels % groups != 0` each raise ValueError.
1558        if in_channels % groups != 0 {
1559            return Err(FerrotorchError::InvalidArgument {
1560                message: format!(
1561                    "Conv1d::new_full: groups={groups} must divide in_channels={in_channels}"
1562                ),
1563            });
1564        }
1565        if out_channels % groups != 0 {
1566            return Err(FerrotorchError::InvalidArgument {
1567                message: format!(
1568                    "Conv1d::new_full: groups={groups} must divide out_channels={out_channels}"
1569                ),
1570            });
1571        }
1572
1573        // PyTorch weight layout is [C_out, C_in / groups, k] (`conv.py:171`).
1574        let mut weight = Parameter::zeros(&[out_channels, in_channels / groups, kernel_size])?;
1575        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
1576
1577        let bias_param = if bias {
1578            let mut b = Parameter::zeros(&[out_channels])?;
1579            // `torch/nn/modules/conv.py:198-201`: bias U(-bound, bound) with
1580            //   `bound = 1 / sqrt(fan_in)`, `fan_in = (in_channels/groups) * kernel_size`.
1581            let fan_in = (in_channels / groups) * kernel_size;
1582            let bound = if fan_in > 0 {
1583                1.0 / (fan_in as f64).sqrt()
1584            } else {
1585                0.0
1586            };
1587            uniform_init(&mut b, -bound, bound)?;
1588            Some(b)
1589        } else {
1590            None
1591        };
1592
1593        Ok(Self {
1594            weight,
1595            bias: bias_param,
1596            in_channels,
1597            out_channels,
1598            kernel_size,
1599            stride,
1600            padding,
1601            dilation,
1602            groups,
1603            padding_mode: crate::padding::PaddingMode::Zeros,
1604            string_padding: None,
1605            training: true,
1606        })
1607    }
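/*
Illustrative sketch, not part of this crate: how `groups` determines the weight
shape and the bias-initialization bound described for `new_full`. The function
name, the `String` error type and the tuple return are assumptions for this
example; the real constructor returns `FerrotorchResult<Self>`.

```rust
/// Returns the weight shape `[out_channels, in_channels / groups, kernel_size]`
/// and the bias bound `1 / sqrt(fan_in)` with `fan_in = (in_channels / groups) * kernel_size`.
fn grouped_conv1d_layout_sketch(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    groups: usize,
) -> Result<([usize; 3], f64), String> {
    if in_channels == 0 || out_channels == 0 || kernel_size == 0 || groups == 0 {
        return Err("all arguments must be > 0".to_string());
    }
    if in_channels % groups != 0 {
        return Err(format!("groups={groups} must divide in_channels={in_channels}"));
    }
    if out_channels % groups != 0 {
        return Err(format!("groups={groups} must divide out_channels={out_channels}"));
    }
    // Each output channel only sees the input channels of its own group.
    let in_per_group = in_channels / groups;
    let fan_in = in_per_group * kernel_size;
    let bound = 1.0 / (fan_in as f64).sqrt();
    Ok(([out_channels, in_per_group, kernel_size], bound))
}
```

With 8 input channels, 16 output channels, `kernel_size = 2` and `groups = 4`,
each filter sees 2 input channels, so `fan_in = 4` and the bound is `0.5`.
*/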
1608
1609    /// Number of channel groups (`1` is dense, `in_channels` is depthwise).
1610    pub fn groups(&self) -> usize {
1611        self.groups
1612    }
1613
1614    /// Dilation (`1` is the dense default).
1615    pub fn dilation(&self) -> usize {
1616        self.dilation
1617    }
1618
1619    /// Configure string padding (`'same'` / `'valid'`), mirroring the
1620    /// `padding: str` branch of `torch.nn.Conv1d` (`conv.py:111-155`).
1621    ///
1622    /// `StringPadding::Valid` is equivalent to `padding = 0`.
1623    /// `StringPadding::Same` pads so the output length equals the input length
1624    /// (for `stride = 1`), splitting the total `dilation * (kernel - 1)`
1625    /// asymmetrically as `left = total/2`, `right = total - left` (the END
1626    /// gets the extra unit; see [`same_pad_lr`]). The pre-pad uses the
1627    /// configured `padding_mode` (constant-0 for the default `Zeros`, matching
1628    /// `convolution_same`'s `constant_pad_nd(.., 0)`, `Convolution.cpp:1105`)
1629    /// and is autograd-aware via `Pad1dBackward`.
1630    ///
1631    /// Returns `Err` if `StringPadding::Same` is requested with `stride != 1`,
1632    /// matching upstream `raise ValueError("padding='same' is not supported
1633    /// for strided convolutions")` (`conv.py:117-120`,
1634    /// `Convolution.cpp:1071`). Closes #1602.
1635    pub fn with_string_padding(mut self, padding: StringPadding) -> FerrotorchResult<Self> {
1636        if padding == StringPadding::Same && self.stride != 1 {
1637            return Err(FerrotorchError::InvalidArgument {
1638                message: "padding='same' is not supported for strided convolutions".into(),
1639            });
1640        }
1641        self.string_padding = Some(padding);
1642        self.padding = 0;
1643        Ok(self)
1644    }
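/*
Illustrative sketch, not part of this crate: the `'same'` padding split
described above, with `total = dilation * (kernel_size - 1)`, `left = total / 2`
and the remainder on the right. This is a stand-in written for this example,
not the crate's own `same_pad_lr`, whose body is not shown here.

```rust
/// Split the total `'same'` padding so that any odd unit goes to the end.
fn same_pad_lr_sketch(kernel_size: usize, dilation: usize) -> (usize, usize) {
    // `saturating_sub` keeps a zero-length kernel from underflowing.
    let total = dilation * kernel_size.saturating_sub(1);
    let left = total / 2;
    (left, total - left)
}
```

An odd kernel with `dilation = 1` pads symmetrically, for example `(1, 1)` for
`kernel_size = 3`, while an even kernel such as `kernel_size = 4` gives `(1, 2)`.
Because the padded length grows by exactly `total`, a stride-1 convolution then
returns the original length.
*/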
1645
1646    /// Configure the boundary handling for the spatial padding.
1647    ///
1648    /// `Zeros` (default) uses the existing im2col zero-pad path.
1649    /// `Reflect`, `Replicate`, and `Circular` pre-pad the input via
1650    /// `crate::padding::functional_pad_1d(input, ...)` and then convolve
1651    /// with `padding = 0`, matching `torch.nn.Conv1d(..., padding_mode=...)`
1652    /// (`_ConvNd._conv_forward`'s `F.pad` shape, `conv.py:367-378`). The pad
1653    /// is autograd-aware (`Pad1dBackward`), so input gradients flow through
1654    /// the boundary. Closes #1443.
1655    pub fn with_padding_mode(mut self, mode: crate::padding::PaddingMode) -> Self {
1656        self.padding_mode = mode;
1657        self
1658    }
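/*
Illustrative sketch, not part of this crate: what the four boundary modes do to
a single 1-D row before the `padding = 0` convolution runs. The enum and the
function are hypothetical stand-ins for `crate::padding::PaddingMode` and
`crate::padding::functional_pad_1d`, whose definitions are not shown here, and
the sketch is not autograd-aware.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum PadModeSketch {
    Zeros,
    Reflect,
    Replicate,
    Circular,
}

/// Pad one row on both sides according to `mode`.
fn pad_1d_sketch(input: &[f32], left: usize, right: usize, mode: PadModeSketch) -> Vec<f32> {
    let n = input.len() as isize;
    assert!(n > 0, "input must be non-empty");
    if mode == PadModeSketch::Reflect {
        // Reflection excludes the edge sample, so each pad must be shorter than the row.
        assert!(left < input.len() && right < input.len());
    }
    let total = left + input.len() + right;
    let mut out = Vec::with_capacity(total);
    for i in 0..total {
        // Index relative to the unpadded row; negative or >= n means padding.
        let j = i as isize - left as isize;
        let v = if (0..n).contains(&j) {
            input[j as usize]
        } else {
            match mode {
                PadModeSketch::Zeros => 0.0,
                PadModeSketch::Replicate => input[j.clamp(0, n - 1) as usize],
                PadModeSketch::Reflect => {
                    let r = if j < 0 { -j } else { 2 * (n - 1) - j };
                    input[r as usize]
                }
                PadModeSketch::Circular => input[j.rem_euclid(n) as usize],
            }
        };
        out.push(v);
    }
    out
}
```

Padding `[1, 2, 3, 4]` by two on each side gives `[3, 2, 1, 2, 3, 4, 3, 2]` for
reflect, `[1, 1, 1, 2, 3, 4, 4, 4]` for replicate and `[3, 4, 1, 2, 3, 4, 1, 2]`
for circular.
*/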
1659
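    /// The number of learnable scalar parameters.
    ///
    /// The weight has shape `[out_channels, in_channels / groups, kernel_size]`,
    /// so the count scales with `in_channels / groups`, not `in_channels`.
    pub fn num_parameters(&self) -> usize {
        let w = self.out_channels * (self.in_channels / self.groups) * self.kernel_size;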
1663        let b = if self.bias.is_some() {
1664            self.out_channels
1665        } else {
1666            0
1667        };
1668        w + b
1669    }
1670
1671    /// Build a `Conv1d` from caller-supplied weight and optional bias tensors.
1672    ///
1673    /// `weight` must have shape `[out_channels, in_channels, kernel_size]`.
1674    /// The resulting layer is dense (`groups = 1`, `dilation = 1`) so the
1675    /// constructor remains API-compatible with `nn::functional::conv1d`,
1676    /// which infers `in_channels = weight.shape()[1]` and cannot recover
1677    /// `groups` from the weight shape alone.
1678    pub fn from_parts(
1679        weight: Tensor<T>,
1680        bias: Option<Tensor<T>>,
1681        stride: usize,
1682        padding: usize,
1683    ) -> FerrotorchResult<Self> {
1684        if weight.ndim() != 3 {
1685            return Err(FerrotorchError::ShapeMismatch {
1686                message: format!(
1687                    "Conv1d::from_parts: weight must be 3-D [out, in, k], got {:?}",
1688                    weight.shape()
1689                ),
1690            });
1691        }
1692        let out_channels = weight.shape()[0];
1693        let in_channels = weight.shape()[1];
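        let kernel_size = weight.shape()[2];
        // `forward` computes `dilation * (kernel_size - 1) + 1` and divides by
        // `stride`, so reject the zero values that `new_full` also rejects.
        if kernel_size == 0 || stride == 0 {
            return Err(FerrotorchError::InvalidArgument {
                message: "Conv1d::from_parts: kernel_size and stride must be > 0".into(),
            });
        }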
1695        if let Some(b) = &bias {
1696            if b.ndim() != 1 || b.shape()[0] != out_channels {
1697                return Err(FerrotorchError::ShapeMismatch {
1698                    message: format!(
1699                        "Conv1d::from_parts: bias shape {:?} != [{}]",
1700                        b.shape(),
1701                        out_channels
1702                    ),
1703                });
1704            }
1705        }
1706        Ok(Self {
1707            weight: Parameter::new(weight),
1708            bias: bias.map(Parameter::new),
1709            in_channels,
1710            out_channels,
1711            kernel_size,
1712            stride,
1713            padding,
1714            dilation: 1,
1715            groups: 1,
1716            padding_mode: crate::padding::PaddingMode::Zeros,
1717            string_padding: None,
1718            training: true,
1719        })
1720    }
1721
1722    /// Build a shallow clone with the geometry overridden (used by `forward`
1723    /// to recurse onto the dense zero-padding im2col path after a
1724    /// string-padding / non-zero `padding_mode` pre-pad). `string_padding` is
1725    /// cleared so the recursion runs the numeric-padding path.
1726    fn recurse_clone(
1727        &self,
1728        padding: usize,
1729        padding_mode: crate::padding::PaddingMode,
1730    ) -> Conv1d<T> {
1731        Conv1d {
1732            weight: Parameter::new(self.weight.tensor().clone()),
1733            bias: self
1734                .bias
1735                .as_ref()
1736                .map(|b| Parameter::new(b.tensor().clone())),
1737            in_channels: self.in_channels,
1738            out_channels: self.out_channels,
1739            kernel_size: self.kernel_size,
1740            stride: self.stride,
1741            padding,
1742            dilation: self.dilation,
1743            groups: self.groups,
1744            padding_mode,
1745            string_padding: None,
1746            training: self.training,
1747        }
1748    }
1749}
1750
1751impl<T: Float> Module<T> for Conv1d<T> {
1752    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1753        // Record autocast decision for conv1d.
1754        let _autocast_cat = autocast_guard("conv1d");
1755
1756        // Unbatched input: `(C, L)` (rank 2) is accepted in addition to the
1757        // batched `(N, C, L)` (rank 3) form. Mirrors `batchify` /
1758        // `_conv_forward` at `aten/src/ATen/native/Convolution.cpp:816-831,
1759        // 990-997`: an unbatched input is `unsqueeze(0)`d, convolved, then
1760        // `squeeze(0)`d so the output is rank 2. The unsqueeze/squeeze are
1761        // autograd-aware so gradients flow back to the unbatched shape. Closes
1762        // #1604.
1763        if input.ndim() == 2 {
1764            let batched = unsqueeze(input, 0)?;
1765            let output = self.forward(&batched)?;
1766            return squeeze(&output, 0);
1767        }
1768
1769        // String padding ('same' / 'valid'), mirroring the `padding: str`
1770        // branch of `torch.nn.Conv1d` (`conv.py:111-155`,
1771        // `Convolution.cpp:1119-1124`). `Valid` == numeric `padding = 0`;
1772        // `Same` pre-pads asymmetrically (`left = total/2`, `right = total -
1773        // left`) via the autograd-aware `functional_pad_1d` then convolves
1774        // with `padding = 0` — the `convolution_same` ->
1775        // `constant_pad_nd(.., 0)` path (`Convolution.cpp:1098-1108`). The
1776        // stride>1 rejection happened at `with_string_padding` construction
1777        // (`conv.py:117-120`). Closes #1602.
1778        if let Some(sp) = self.string_padding {
1779            match sp {
1780                StringPadding::Valid => {
1781                    return self
1782                        .recurse_clone(0, crate::padding::PaddingMode::Zeros)
1783                        .forward(input);
1784                }
1785                StringPadding::Same => {
1786                    let (left, right) = same_pad_lr(self.kernel_size, self.dilation);
1787                    let padded = crate::padding::functional_pad_1d(
1788                        input,
1789                        left,
1790                        right,
1791                        self.padding_mode,
1792                        <T as num_traits::Zero>::zero(),
1793                    )?;
1794                    return self
1795                        .recurse_clone(0, crate::padding::PaddingMode::Zeros)
1796                        .forward(&padded);
1797                }
1798            }
1799        }
1800
1801        // Non-zero padding modes: pre-pad the input with the requested
1802        // boundary mode and then convolve with padding = 0. Mirrors
1803        // `torch/nn/modules/conv.py` `Conv1d._conv_forward` (`conv.py:367-378`):
1804        //   if self.padding_mode != 'zeros':
1805        //       F.conv1d(F.pad(input, self._reversed_padding_repeated_twice,
1806        //                      mode=self.padding_mode), ..., padding=_single(0), ...)
1807        // For an int `padding=p`, `_reversed_padding_repeated_twice` is `[p, p]`
1808        // (`conv.py:157` `_reverse_repeat_tuple(self.padding, 2)`), i.e. a
1809        // symmetric `(pad_left, pad_right) = (p, p)`. The pre-pad is
1810        // autograd-aware (`Pad1dBackward`) so input gradients flow through the
1811        // boundary. Closes #1443.
1812        if self.padding_mode != crate::padding::PaddingMode::Zeros && self.padding != 0 {
1813            let padded = crate::padding::functional_pad_1d(
1814                input,
1815                self.padding,
1816                self.padding,
1817                self.padding_mode,
1818                <T as num_traits::Zero>::zero(),
1819            )?;
1820            // Recurse on a zero-padding variant: build a shallow clone with
1821            // padding = 0 and padding_mode = Zeros so the existing
1822            // im2col-on-zero-pad path runs without re-padding.
1823            return self
1824                .recurse_clone(0, crate::padding::PaddingMode::Zeros)
1825                .forward(&padded);
1826        }
1827
1828        // Validate input shape: [B, C_in, L].
1829        if input.ndim() != 3 {
1830            return Err(FerrotorchError::InvalidArgument {
1831                message: format!(
1832                    "Conv1d expects 3-D input [B, C, L], got {:?}",
1833                    input.shape()
1834                ),
1835            });
1836        }
1837
1838        let batch = input.shape()[0];
1839        let c_in = input.shape()[1];
1840        let length = input.shape()[2];
1841
1842        if c_in != self.in_channels {
1843            return Err(FerrotorchError::ShapeMismatch {
1844                message: format!(
1845                    "Conv1d: expected {} input channels, got {}",
1846                    self.in_channels, c_in
1847                ),
1848            });
1849        }
1850
1851        let k = self.kernel_size;
1852        let s = self.stride;
1853        let p = self.padding;
1854        let dil = self.dilation;
1855        let groups = self.groups;
1856
1857        // Effective kernel extent after dilation, mirroring
1858        // `ConvUtils.h:255` `kernel = dilation * (weight_size - 1) + 1`.
1859        let eff_k = dil * (k - 1) + 1;
1860        let l_padded = length + 2 * p;
1861        if l_padded < eff_k {
1862            return Err(FerrotorchError::InvalidArgument {
1863                message: format!(
1864                    "Conv1d: padded input length ({l_padded}) is smaller than effective kernel ({eff_k})"
1865                ),
1866            });
1867        }
1868
1869        let l_out = (l_padded - eff_k) / s + 1;
1870
1871        // Save the input device so we can restore it on the output.
1872        let input_device = input.device();
1873
1874        // Reshape input [B, C_in, L] -> [B, C_in, 1, L] and use the 2-D dilated
1875        // im2col with kernel (1, k), stride (1, s), padding (0, p), dilation
1876        // (1, dil) so the temporal dilation maps to the W axis. The CPU path
1877        // partitions channels by `groups` exactly like Conv2d: each group's
1878        // input slice [B, in_per_group, L] is convolved with its weight slice
1879        // and the outputs are stacked along the C_out axis (mirroring the
1880        // per-group subtensor/cat loop at `Convolution.cpp:1723-1729`).
1881        let input_data = input.data_vec()?;
1882        let weight_data = self.weight.data_vec()?;
1883
1884        let zero = <T as num_traits::Zero>::zero();
1885        let mut output = vec![zero; batch * self.out_channels * l_out];
1886
1887        // Per-group dimensions.
1888        let in_per_group = self.in_channels / groups;
1889        let out_per_group = self.out_channels / groups;
1890        let weight_per_group_numel = out_per_group * in_per_group * k;
1891        let group_col_rows = in_per_group * k;
1892        let col_cols = l_out;
1893
1894        // Saved im2col columns for autograd (dense channel layout `[B,
1895        // C_in * k, L_out]` so the backward can accumulate grad_input back
1896        // into a `[B, C_in, L]` tensor uniformly across groups, exactly like
1897        // Conv2dBackward).
1898        let saved_cols_rows = self.in_channels * k;
1899        let mut saved_cols: Vec<T> = if is_grad_enabled()
1900            && (input.requires_grad()
1901                || self.weight.requires_grad()
1902                || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
1903        {
1904            vec![zero; batch * saved_cols_rows * col_cols]
1905        } else {
1906            Vec::new()
1907        };
1908        let save_cols = !saved_cols.is_empty();
1909
1910        for g in 0..groups {
1911            // Slice the input channels belonging to this group: [B, in_per_group, L].
1912            let mut group_input = vec![zero; batch * in_per_group * length];
1913            for b in 0..batch {
1914                for c in 0..in_per_group {
1915                    let src_c = g * in_per_group + c;
1916                    let src_start = b * self.in_channels * length + src_c * length;
1917                    let dst_start = b * in_per_group * length + c * length;
1918                    group_input[dst_start..dst_start + length]
1919                        .copy_from_slice(&input_data[src_start..src_start + length]);
1920                }
1921            }
1922
1923            let (g_cols, g_col_rows, g_col_cols) = im2col_dilated(
1924                &group_input,
1925                batch,
1926                in_per_group,
1927                1,
1928                length,
1929                1,
1930                k,
1931                1,
1932                s,
1933                0,
1934                p,
1935                1,
1936                dil,
1937            );
1938            debug_assert_eq!(g_col_rows, group_col_rows);
1939            debug_assert_eq!(g_col_cols, col_cols);
1940
1941            // Save into the dense [C_in * k, col_cols] layout if backward needs it.
1942            if save_cols {
1943                for b in 0..batch {
1944                    for c in 0..in_per_group {
1945                        let dst_c = g * in_per_group + c;
1946                        for kk in 0..k {
1947                            let src_row = c * k + kk;
1948                            let dst_row = dst_c * k + kk;
1949                            let src_off = b * group_col_rows * col_cols + src_row * col_cols;
1950                            let dst_off = b * saved_cols_rows * col_cols + dst_row * col_cols;
1951                            saved_cols[dst_off..dst_off + col_cols]
1952                                .copy_from_slice(&g_cols[src_off..src_off + col_cols]);
1953                        }
1954                    }
1955                }
1956            }
1957
1958            // Group's slice of the weight: [out_per_group, in_per_group, k]
1959            // flattened to [out_per_group, group_col_rows].
1960            let w_group_start = g * weight_per_group_numel;
1961            let w_group_end = w_group_start + weight_per_group_numel;
1962            let weight_group_2d = Tensor::from_storage(
1963                TensorStorage::cpu(weight_data[w_group_start..w_group_end].to_vec()),
1964                vec![out_per_group, group_col_rows],
1965                false,
1966            )?;
1967
1968            for b in 0..batch {
1969                let col_start = b * group_col_rows * col_cols;
1970                let col_end = col_start + group_col_rows * col_cols;
1971                let cols_b = Tensor::from_storage(
1972                    TensorStorage::cpu(g_cols[col_start..col_end].to_vec()),
1973                    vec![group_col_rows, col_cols],
1974                    false,
1975                )?;
1976
1977                let out_b = mm(&weight_group_2d, &cols_b)?;
1978                let out_data = out_b.data()?;
1979                // Place this group's output channels into [b, g*out_per_group.., :].
1980                for oc in 0..out_per_group {
1981                    let dst_c = g * out_per_group + oc;
1982                    let dst_start = b * self.out_channels * l_out + dst_c * l_out;
1983                    let src_start = oc * l_out;
1984                    output[dst_start..dst_start + l_out]
1985                        .copy_from_slice(&out_data[src_start..src_start + l_out]);
1986                }
1987            }
1988        }
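/*
Illustrative sketch, not part of this crate: the per-group im2col + matmul loop
above, reduced to one sample, plain `f32` slices and a hand-written matrix
product. Both function names and their argument order are assumptions for this
example; the layer itself goes through `im2col_dilated` and `mm`.

```rust
/// Unfold one sample `[C, L]` into columns `[C * k, L_out]`.
/// Taps that fall in the zero padding stay 0.
fn im2col_1d_sketch(
    x: &[f32], c: usize, len: usize, k: usize, s: usize, p: usize, d: usize, l_out: usize,
) -> Vec<f32> {
    let mut cols = vec![0.0f32; c * k * l_out];
    for ch in 0..c {
        for kk in 0..k {
            for o in 0..l_out {
                let pos = o * s + kk * d; // position in the padded input
                if pos >= p && pos - p < len {
                    cols[(ch * k + kk) * l_out + o] = x[ch * len + pos - p];
                }
            }
        }
    }
    cols
}

/// Grouped, dilated conv for one sample.
/// `x` is `[C_in, L]`, `w` is `[C_out, C_in / groups, k]`, result is `[C_out, L_out]`.
fn grouped_conv1d_sketch(
    x: &[f32], w: &[f32], c_in: usize, c_out: usize, len: usize,
    k: usize, s: usize, p: usize, d: usize, groups: usize,
) -> Vec<f32> {
    assert!(groups > 0 && c_in % groups == 0 && c_out % groups == 0);
    assert!(k > 0 && s > 0 && d > 0);
    let eff_k = d * (k - 1) + 1;
    assert!(len + 2 * p >= eff_k);
    let l_out = (len + 2 * p - eff_k) / s + 1;
    let (ipg, opg) = (c_in / groups, c_out / groups);
    let rows = ipg * k;
    let mut out = vec![0.0f32; c_out * l_out];
    for g in 0..groups {
        // Contiguous input-channel and weight slices owned by group `g`.
        let xg = &x[g * ipg * len..(g + 1) * ipg * len];
        let wg = &w[g * opg * rows..(g + 1) * opg * rows];
        let cols = im2col_1d_sketch(xg, ipg, len, k, s, p, d, l_out);
        // [opg, rows] @ [rows, l_out], written into this group's output channels.
        for oc in 0..opg {
            for o in 0..l_out {
                let mut acc = 0.0f32;
                for r in 0..rows {
                    acc += wg[oc * rows + r] * cols[r * l_out + o];
                }
                out[(g * opg + oc) * l_out + o] = acc;
            }
        }
    }
    out
}
```

With `groups = c_in = c_out` this degenerates to a depthwise convolution, where
each output channel is computed from exactly one input channel.
*/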
1989
1990        // Add bias if present: broadcast [C_out] over [B, C_out, L_out].
1991        if let Some(ref bias) = self.bias {
1992            let bias_data = bias.data_vec()?;
1993            for b in 0..batch {
1994                for c in 0..self.out_channels {
1995                    let bval = bias_data[c];
1996                    for l in 0..l_out {
1997                        output[b * self.out_channels * l_out + c * l_out + l] += bval;
1998                    }
1999                }
2000            }
2001        }
2002
2003        let result = Tensor::from_storage(
2004            TensorStorage::cpu(output),
2005            vec![batch, self.out_channels, l_out],
2006            false,
2007        )?;
2008
2009        // Attach backward if gradients are enabled.
2010        if save_cols {
2011            let grad_fn = Arc::new(Conv1dBackward {
2012                input: input.clone(),
2013                weight: self.weight.tensor().clone(),
2014                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
2015                in_channels: self.in_channels,
2016                out_channels: self.out_channels,
2017                kernel_size: self.kernel_size,
2018                stride: self.stride,
2019                padding: self.padding,
2020                dilation: self.dilation,
2021                groups: self.groups,
2022                cols: saved_cols,
2023                col_rows: saved_cols_rows,
2024                col_cols,
2025                l_out,
2026            });
2027            Tensor::from_operation(
2028                TensorStorage::cpu(result.data()?.to_vec()),
2029                result.shape().to_vec(),
2030                grad_fn,
2031            )?
2032            .to(input_device) // restore device
2033        } else {
2034            result.to(input_device)
2035        }
2036    }
2037
2038    fn parameters(&self) -> Vec<&Parameter<T>> {
2039        let mut params = vec![&self.weight];
2040        if let Some(ref b) = self.bias {
2041            params.push(b);
2042        }
2043        params
2044    }
2045
2046    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
2047        let mut params = vec![&mut self.weight];
2048        if let Some(ref mut b) = self.bias {
2049            params.push(b);
2050        }
2051        params
2052    }
2053
2054    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
2055        let mut params = vec![("weight".to_string(), &self.weight)];
2056        if let Some(ref b) = self.bias {
2057            params.push(("bias".to_string(), b));
2058        }
2059        params
2060    }
2061
2062    fn train(&mut self) {
2063        self.training = true;
2064    }
2065
2066    fn eval(&mut self) {
2067        self.training = false;
2068    }
2069
2070    fn is_training(&self) -> bool {
2071        self.training
2072    }
2073}
2074
2075// ---------------------------------------------------------------------------
2076// Conv1dBackward
2077// ---------------------------------------------------------------------------
2078
/// Backward function for the `Conv1d` forward pass.
///
/// Saved `cols` use the **dense channel layout** `[B, C_in * k, L_out]`
/// (the forward saves into this shape regardless of `groups`), mirroring
/// `Conv2dBackward`'s grouped scheme: the per-group slice is taken at
/// gradient-computation time, so `grad_input` accumulates uniformly across
/// groups. `dilation` and `groups` are kept to reconstruct the per-group,
/// dilated math.
2086#[derive(Debug)]
2087struct Conv1dBackward<T: Float> {
2088    input: Tensor<T>,
2089    weight: Tensor<T>,
2090    bias: Option<Tensor<T>>,
2091    in_channels: usize,
2092    out_channels: usize,
2093    kernel_size: usize,
2094    stride: usize,
2095    padding: usize,
2096    dilation: usize,
2097    groups: usize,
2098    cols: Vec<T>,
2099    col_rows: usize,
2100    col_cols: usize,
2101    l_out: usize,
2102}
2103
2104impl<T: Float> GradFn<T> for Conv1dBackward<T> {
2105    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2106        // grad_output shape: [B, C_out, L_out]
2107        let input_device = self.input.device();
2108        let weight_device = self.weight.device();
2109        let bias_device = self.bias.as_ref().map(|b| b.device());
2110        let go_data = grad_output.data_vec()?;
2111        let batch = self.input.shape()[0];
2112        let length = self.input.shape()[2];
2113        let k = self.kernel_size;
2114        let s = self.stride;
2115        let p = self.padding;
2116        let dil = self.dilation;
2117        let groups = self.groups;
2118        let in_per_group = self.in_channels / groups;
2119        let out_per_group = self.out_channels / groups;
2120        let group_col_rows = in_per_group * k;
2121        let zero = <T as num_traits::Zero>::zero();
2122
2123        // --- grad_weight ---
2124        // Per group `g`: gw_g += grad_output_b_g @ cols_b_g^T, stacked along
2125        // the C_out axis to recover [C_out, C_in/G, k]. Mirrors Conv2dBackward.
2126        let grad_weight = if self.weight.requires_grad() {
2127            let weight_numel = self.out_channels * in_per_group * k;
2128            let mut gw_accum = vec![zero; weight_numel];
2129            let weight_per_group_numel = out_per_group * group_col_rows;
2130
2131            for g in 0..groups {
2132                for b in 0..batch {
2133                    // Slice grad_output for this group: [out_per_group, l_out].
2134                    let mut go_g = vec![zero; out_per_group * self.l_out];
2135                    for oc in 0..out_per_group {
2136                        let src_c = g * out_per_group + oc;
2137                        let src_start = b * self.out_channels * self.l_out + src_c * self.l_out;
2138                        let dst_start = oc * self.l_out;
2139                        go_g[dst_start..dst_start + self.l_out]
2140                            .copy_from_slice(&go_data[src_start..src_start + self.l_out]);
2141                    }
2142                    let go_b_g = Tensor::from_storage(
2143                        TensorStorage::cpu(go_g),
2144                        vec![out_per_group, self.l_out],
2145                        false,
2146                    )?;
2147
2148                    // Slice cols for this group: [in_per_group * k, col_cols].
2149                    let mut cols_g = vec![zero; group_col_rows * self.col_cols];
2150                    for c in 0..in_per_group {
2151                        let src_c = g * in_per_group + c;
2152                        for kk in 0..k {
2153                            let src_row = src_c * k + kk;
2154                            let dst_row = c * k + kk;
2155                            let src_off =
2156                                b * self.col_rows * self.col_cols + src_row * self.col_cols;
2157                            let dst_off = dst_row * self.col_cols;
2158                            cols_g[dst_off..dst_off + self.col_cols]
2159                                .copy_from_slice(&self.cols[src_off..src_off + self.col_cols]);
2160                        }
2161                    }
2162                    let cols_b_g = Tensor::from_storage(
2163                        TensorStorage::cpu(cols_g),
2164                        vec![group_col_rows, self.col_cols],
2165                        false,
2166                    )?;
2167
2168                    let cols_bt = transpose(&cols_b_g)?;
2169                    let gw_b = mm(&go_b_g, &cols_bt)?;
2170                    let gw_data = gw_b.data()?;
2171
2172                    let dst_off = g * weight_per_group_numel;
2173                    for i in 0..weight_per_group_numel {
2174                        gw_accum[dst_off + i] += gw_data[i];
2175                    }
2176                }
2177            }
2178
2179            Some(
2180                Tensor::from_storage(
2181                    TensorStorage::cpu(gw_accum),
2182                    vec![self.out_channels, in_per_group, k],
2183                    false,
2184                )?
2185                .to(weight_device)?,
2186            )
2187        } else {
2188            None
2189        };
2190
2191        // --- grad_bias ---
2192        // Sum grad_output over batch + length. Bias is per-output-channel
2193        // ([C_out]), identical for any groups setting.
2194        let grad_bias = match &self.bias {
2195            Some(b) if b.requires_grad() => {
2196                let mut gb = vec![zero; self.out_channels];
2197                for batch_idx in 0..batch {
2198                    for c in 0..self.out_channels {
2199                        for l in 0..self.l_out {
2200                            gb[c] += go_data
2201                                [batch_idx * self.out_channels * self.l_out + c * self.l_out + l];
2202                        }
2203                    }
2204                }
2205                let target_dev = bias_device.unwrap_or(input_device);
2206                Some(
2207                    Tensor::from_storage(TensorStorage::cpu(gb), vec![self.out_channels], false)?
2208                        .to(target_dev)?,
2209                )
2210            }
2211            _ => None,
2212        };
2213
        // --- grad_input ---
        // Per group `g` and batch item `b`: weight_g_2d^T @ grad_output_b_g ->
        // [in_per_group * k, l_out]. The per-batch results are stacked, then
        // col2im_dilated -> [B, in_per_group, 1, L], which is copied into the
        // matching in-channel slice of [B, C_in, L]. Mirrors Conv2dBackward.
2218        let grad_input = if self.input.requires_grad() {
2219            let weight_data = self.weight.data_vec()?;
2220            let mut grad_input_data = vec![zero; batch * self.in_channels * length];
2221            let weight_per_group_numel = out_per_group * group_col_rows;
2222
2223            for g in 0..groups {
2224                let w_off = g * weight_per_group_numel;
2225                let weight_g_2d = Tensor::from_storage(
2226                    TensorStorage::cpu(weight_data[w_off..w_off + weight_per_group_numel].to_vec()),
2227                    vec![out_per_group, group_col_rows],
2228                    false,
2229                )?;
2230                let weight_g_t = transpose(&weight_g_2d)?;
2231
2232                let mut grad_cols_g = vec![zero; batch * group_col_rows * self.col_cols];
2233                for b in 0..batch {
2234                    let mut go_g = vec![zero; out_per_group * self.l_out];
2235                    for oc in 0..out_per_group {
2236                        let src_c = g * out_per_group + oc;
2237                        let src_start = b * self.out_channels * self.l_out + src_c * self.l_out;
2238                        let dst_start = oc * self.l_out;
2239                        go_g[dst_start..dst_start + self.l_out]
2240                            .copy_from_slice(&go_data[src_start..src_start + self.l_out]);
2241                    }
2242                    let go_b_g = Tensor::from_storage(
2243                        TensorStorage::cpu(go_g),
2244                        vec![out_per_group, self.l_out],
2245                        false,
2246                    )?;
2247
2248                    let gc_b = mm(&weight_g_t, &go_b_g)?;
2249                    let gc_data = gc_b.data()?;
2250                    let gc_start = b * group_col_rows * self.col_cols;
2251                    grad_cols_g[gc_start..gc_start + group_col_rows * self.col_cols]
2252                        .copy_from_slice(gc_data);
2253                }
2254
                // col2im_dilated scatters this group's columns back to
                // [B, in_per_group, 1, L]. The 1-D signal is treated as a
                // height-1 image, so only the W axis carries the stride,
                // padding, and dilation.
2257                let gi_g = col2im_dilated(
2258                    &grad_cols_g,
2259                    batch,
2260                    in_per_group,
2261                    1,
2262                    length,
2263                    1,
2264                    k,
2265                    1,
2266                    s,
2267                    0,
2268                    p,
2269                    1,
2270                    dil,
2271                    1,
2272                    self.l_out,
2273                );
2274
2275                for b in 0..batch {
2276                    for c in 0..in_per_group {
2277                        let dst_c = g * in_per_group + c;
2278                        let dst_start = b * self.in_channels * length + dst_c * length;
2279                        let src_start = b * in_per_group * length + c * length;
2280                        grad_input_data[dst_start..dst_start + length]
2281                            .copy_from_slice(&gi_g[src_start..src_start + length]);
2282                    }
2283                }
2284            }
2285
2286            Some(
2287                Tensor::from_storage(
2288                    TensorStorage::cpu(grad_input_data),
2289                    self.input.shape().to_vec(),
2290                    false,
2291                )?
2292                .to(input_device)?,
2293            )
2294        } else {
2295            None
2296        };
2297
2298        let mut grads = vec![grad_input, grad_weight];
2299        if self.bias.is_some() {
2300            grads.push(grad_bias);
2301        }
2302        Ok(grads)
2303    }
2304
2305    fn inputs(&self) -> Vec<&Tensor<T>> {
2306        let mut v = vec![&self.input, &self.weight];
2307        if let Some(ref b) = self.bias {
2308            v.push(b);
2309        }
2310        v
2311    }
2312
2313    fn name(&self) -> &'static str {
2314        "Conv1dBackward"
2315    }
2316}
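
// Illustrative sketch, not part of the layer: the `grad_bias` reduction in
// `Conv1dBackward::backward` above sums `grad_output` over the batch and
// length axes. The helper below restates that reduction over a flat,
// row-major `[B, C_out, L_out]` buffer. The name `grad_bias_1d_sketch` and the
// fixed `f32` element type are assumptions made for this example only.
#[allow(dead_code)]
fn grad_bias_1d_sketch(
    grad_output: &[f32],
    batch: usize,
    c_out: usize,
    l_out: usize,
) -> Vec<f32> {
    assert_eq!(grad_output.len(), batch * c_out * l_out);
    let mut gb = vec![0.0f32; c_out];
    if l_out == 0 {
        // Nothing to sum; this also avoids `chunks_exact(0)`, which panics.
        return gb;
    }
    // Row `i` of the buffer is the `(b, c)` pair with `c = i % c_out`.
    for (i, row) in grad_output.chunks_exact(l_out).enumerate() {
        gb[i % c_out] += row.iter().sum::<f32>();
    }
    gb
}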
2317
2318// ---------------------------------------------------------------------------
2319// ConvTranspose2d
2320// ---------------------------------------------------------------------------
2321
2322/// A 2-D transposed convolution (deconvolution) layer.
2323///
2324/// Applies a transposed spatial convolution over an input `[B, C_in, H, W]`.
2325/// Used for upsampling in generative models and decoder networks.
2326/// Equivalent to `torch.nn.ConvTranspose2d`.
2327///
/// # Implementation
///
/// The forward pass inserts `stride - 1` zeros between adjacent input
/// elements (fractionally-strided convolution), then applies a stride-1,
/// `dilation`-spaced convolution with the kernel flipped along both
/// spatial axes.
///
/// # Shape
///
/// - Input: `[B, in_channels, H, W]`
/// - Output: `[B, out_channels, H_out, W_out]`
///
/// where `H_out = (H - 1) * stride.0 - 2 * padding.0 + dilation.0 * (kernel_size.0 - 1)
/// + output_padding.0 + 1`, and `W_out` is the same expression over the `.1`
/// components. With `dilation = (1, 1)` this reduces to
/// `(H - 1) * stride.0 - 2 * padding.0 + kernel_size.0 + output_padding.0`.
2340#[derive(Debug)]
2341pub struct ConvTranspose2d<T: Float> {
2342    /// Learnable kernel weights `[in_channels, out_channels / groups, kH, kW]`.
2343    ///
2344    /// Note: the channel ordering is transposed compared to `Conv2d`
2345    /// (`in_channels` first), per `torch/nn/modules/conv.py:161-167`.
2346    weight: Parameter<T>,
2347    /// Optional learnable bias `[out_channels]`.
2348    bias: Option<Parameter<T>>,
2349    /// Number of input channels.
2350    in_channels: usize,
2351    /// Number of output channels.
2352    out_channels: usize,
2353    /// Kernel spatial size `(kH, kW)`.
2354    kernel_size: (usize, usize),
2355    /// Stride `(sH, sW)`.
2356    stride: (usize, usize),
2357    /// Zero-padding `(pH, pW)` removed from both sides of the output.
2358    padding: (usize, usize),
2359    /// Additional size added to one side of the output `(opH, opW)`.
2360    output_padding: (usize, usize),
2361    /// Dilation `(dilH, dilW)`. `(1, 1)` is the dense default. Spaces the
2362    /// kernel taps in the internal stride-1 convolution
2363    /// (`torch/nn/modules/conv.py:1198-1207`, `dilation` arg of
2364    /// `F.conv_transpose2d`).
2365    dilation: (usize, usize),
2366    /// Number of blocked input/output channel groups. `1` is dense. Must divide
2367    /// both `in_channels` and `out_channels`. The transposed weight is
2368    /// `[in_channels, out_channels / groups, kH, kW]`; per group the input is
2369    /// partitioned on the channel axis (`in_channels / groups` each) and each
2370    /// slab is convolved-transposed with its `[in/groups, out/groups, kH, kW]`
2371    /// weight slab, the outputs concatenated on the channel axis — exactly
2372    /// `aten/src/ATen/native/Convolution.cpp:1723-1729`.
2373    groups: usize,
2374    /// Whether the module is in training mode.
2375    training: bool,
2376}
2377
2378impl<T: Float> ConvTranspose2d<T> {
2379    /// Create a new `ConvTranspose2d` layer (dense, dilation `(1, 1)`,
2380    /// `groups = 1`).
2381    ///
2382    /// Weight is initialized with Kaiming uniform (ReLU gain).
2383    /// Bias, if enabled, is initialized U(-bound, bound) with
2384    /// `bound = 1/sqrt(fan_in)` per `torch/nn/modules/conv.py:198-201`.
2385    ///
2386    /// Thin shim over [`ConvTranspose2d::new_full`] preserved for the existing
2387    /// `new(.., bias)` callers (e.g. `ferrotorch-vision` detection heads).
2388    pub fn new(
2389        in_channels: usize,
2390        out_channels: usize,
2391        kernel_size: (usize, usize),
2392        stride: (usize, usize),
2393        padding: (usize, usize),
2394        output_padding: (usize, usize),
2395        bias: bool,
2396    ) -> FerrotorchResult<Self> {
2397        Self::new_full(
2398            in_channels,
2399            out_channels,
2400            kernel_size,
2401            stride,
2402            padding,
2403            output_padding,
2404            (1, 1),
2405            1,
2406            bias,
2407        )
2408    }
2409
2410    /// Create a new `ConvTranspose2d` with the full PyTorch-shaped argument set,
2411    /// including `dilation` and `groups`.
2412    ///
2413    /// Mirrors `torch.nn.ConvTranspose2d.__init__` (`torch/nn/modules/conv.py:
2414    /// 1133-1167`): the argument order is `(in, out, kernel, stride, padding,
2415    /// output_padding, dilation, groups, bias)`. `groups` must divide BOTH
2416    /// `in_channels` and `out_channels` (upstream `_ConvNd.__init__` raises
2417    /// `ValueError` otherwise, `conv.py:105-110`). The transposed weight layout
2418    /// is `[in_channels, out_channels / groups, kH, kW]` (`conv.py:161-167`).
2419    #[allow(clippy::too_many_arguments)]
2420    pub fn new_full(
2421        in_channels: usize,
2422        out_channels: usize,
2423        kernel_size: (usize, usize),
2424        stride: (usize, usize),
2425        padding: (usize, usize),
2426        output_padding: (usize, usize),
2427        dilation: (usize, usize),
2428        groups: usize,
2429        bias: bool,
2430    ) -> FerrotorchResult<Self> {
2431        if in_channels == 0 || out_channels == 0 {
2432            return Err(FerrotorchError::InvalidArgument {
2433                message: "in_channels and out_channels must be > 0".into(),
2434            });
2435        }
2436        if kernel_size.0 == 0 || kernel_size.1 == 0 {
2437            return Err(FerrotorchError::InvalidArgument {
2438                message: "kernel_size must be > 0 in both dimensions".into(),
2439            });
2440        }
2441        if stride.0 == 0 || stride.1 == 0 {
2442            return Err(FerrotorchError::InvalidArgument {
2443                message: "stride must be > 0 in both dimensions".into(),
2444            });
2445        }
2446        if dilation.0 == 0 || dilation.1 == 0 {
2447            return Err(FerrotorchError::InvalidArgument {
2448                message: "dilation must be > 0 in both dimensions".into(),
2449            });
2450        }
2451        // `_ConvNd.__init__` (`conv.py:105-110`): groups must be positive and
2452        // divide both in_channels and out_channels.
2453        if groups == 0 {
2454            return Err(FerrotorchError::InvalidArgument {
2455                message: "groups must be a positive integer".into(),
2456            });
2457        }
2458        if in_channels % groups != 0 {
2459            return Err(FerrotorchError::InvalidArgument {
2460                message: format!(
2461                    "in_channels ({in_channels}) must be divisible by groups ({groups})"
2462                ),
2463            });
2464        }
2465        if out_channels % groups != 0 {
2466            return Err(FerrotorchError::InvalidArgument {
2467                message: format!(
2468                    "out_channels ({out_channels}) must be divisible by groups ({groups})"
2469                ),
2470            });
2471        }
2472        // `output_padding` must be strictly less than `max(stride, dilation)`
2473        // (upstream `_output_padding` valid-range check, `conv.py:803-822`).
2474        if output_padding.0 >= stride.0.max(dilation.0)
2475            || output_padding.1 >= stride.1.max(dilation.1)
2476        {
2477            return Err(FerrotorchError::InvalidArgument {
2478                message: "output_padding must be strictly less than max(stride, dilation)".into(),
2479            });
2480        }
2481
2482        // Weight shape: [in_channels, out_channels / groups, kH, kW] (transposed
2483        // layout per `torch/nn/modules/conv.py:161-167`).
2484        let (kh, kw) = kernel_size;
2485        let out_per_group = out_channels / groups;
2486        let mut weight = Parameter::zeros(&[in_channels, out_per_group, kh, kw])?;
2487        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
2488
2489        let bias_param = if bias {
2490            let mut b = Parameter::zeros(&[out_channels])?;
2491            // `torch/nn/modules/conv.py:198-201`: bias U(-bound, bound) with
2492            //   `bound = 1 / sqrt(fan_in)`. For ConvTranspose2d weight shape
2493            //   `[in_channels, out_channels/groups, kH, kW]`,
2494            //   `_calculate_fan_in_and_fan_out` yields
2495            //   `fan_in = (out_channels/groups) * kH * kW`.
2496            let fan_in = out_per_group * kh * kw;
2497            let bound = if fan_in > 0 {
2498                1.0 / (fan_in as f64).sqrt()
2499            } else {
2500                0.0
2501            };
2502            uniform_init(&mut b, -bound, bound)?;
2503            Some(b)
2504        } else {
2505            None
2506        };
2507
2508        Ok(Self {
2509            weight,
2510            bias: bias_param,
2511            in_channels,
2512            out_channels,
2513            kernel_size,
2514            stride,
2515            padding,
2516            output_padding,
2517            dilation,
2518            groups,
2519            training: true,
2520        })
2521    }
2522
2523    /// Configure the boundary handling for the spatial padding.
2524    ///
2525    /// Only [`crate::padding::PaddingMode::Zeros`] is accepted: upstream
2526    /// `_ConvTransposeNd.__init__` raises
2527    /// `ValueError('Only "zeros" padding mode is supported for ConvTranspose2d')`
2528    /// for any non-`zeros` mode (`torch/nn/modules/conv.py:755-758`). This
2529    /// matches that behaviour by returning an error rather than silently
2530    /// accepting the unsupported mode (R-DEV-2). The returned layer is
2531    /// unchanged (the only valid mode is `Zeros`, the constructed default).
2532    /// Closes #1443.
2533    pub fn with_padding_mode(self, mode: crate::padding::PaddingMode) -> FerrotorchResult<Self> {
2534        reject_non_zeros_transpose(mode, "ConvTranspose2d")?;
2535        Ok(self)
2536    }
2537
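    /// The number of learnable scalar parameters.
    ///
    /// The weight is `[in_channels, out_channels / groups, kH, kW]`, so the
    /// weight count is divided by `groups` relative to the dense case.
    pub fn num_parameters(&self) -> usize {
        let (kh, kw) = self.kernel_size;
        // `groups >= 1` and divides `out_channels` (validated in `new_full`;
        // `from_parts` always sets `groups = 1`).
        let w = self.in_channels * (self.out_channels / self.groups) * kh * kw;
        let b = if self.bias.is_some() {
            self.out_channels
        } else {
            0
        };
        w + b
    }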
2548
    /// Build a dense `ConvTranspose2d` from caller-supplied weight and optional
    /// bias.
    ///
    /// `weight` must have shape `[in_channels, out_channels, kH, kW]` (note the
    /// transposed channel ordering vs `Conv2d`). The result always has
    /// `groups = 1` and `dilation = (1, 1)`; use [`ConvTranspose2d::new_full`]
    /// for grouped or dilated layers. Used by
    /// `nn::functional::conv_transpose2d`.
2554    pub fn from_parts(
2555        weight: Tensor<T>,
2556        bias: Option<Tensor<T>>,
2557        stride: (usize, usize),
2558        padding: (usize, usize),
2559        output_padding: (usize, usize),
2560    ) -> FerrotorchResult<Self> {
2561        if weight.ndim() != 4 {
2562            return Err(FerrotorchError::ShapeMismatch {
2563                message: format!(
2564                    "ConvTranspose2d::from_parts: weight must be 4-D [in, out, kH, kW], got {:?}",
2565                    weight.shape()
2566                ),
2567            });
2568        }
2569        let in_channels = weight.shape()[0];
2570        let out_channels = weight.shape()[1];
2571        let kernel_size = (weight.shape()[2], weight.shape()[3]);
2572        if output_padding.0 >= stride.0 || output_padding.1 >= stride.1 {
2573            return Err(FerrotorchError::InvalidArgument {
2574                message: "output_padding must be strictly less than stride".into(),
2575            });
2576        }
2577        if let Some(b) = &bias {
2578            if b.ndim() != 1 || b.shape()[0] != out_channels {
2579                return Err(FerrotorchError::ShapeMismatch {
2580                    message: format!(
2581                        "ConvTranspose2d::from_parts: bias shape {:?} != [{}]",
2582                        b.shape(),
2583                        out_channels
2584                    ),
2585                });
2586            }
2587        }
2588        Ok(Self {
2589            weight: Parameter::new(weight),
2590            bias: bias.map(Parameter::new),
2591            in_channels,
2592            out_channels,
2593            kernel_size,
2594            stride,
2595            padding,
2596            output_padding,
2597            // `from_parts` recovers `out_channels` from the weight's dim 1, which
2598            // for a grouped weight is `out_channels / groups`; the group count
2599            // cannot be inferred from the weight shape alone, so this builder is
2600            // dense (`groups = 1`, `dilation = (1, 1)`) — matching the
2601            // ABI-compatible `Conv2d::from_parts` policy. Grouped/dilated
2602            // transposed convs go through `new_full`.
2603            dilation: (1, 1),
2604            groups: 1,
2605            training: true,
2606        })
2607    }
2608}
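
// Illustrative sketch, not part of the public API: the per-axis output extent
// of a transposed convolution, following the shape formula documented for
// `torch.nn.ConvTranspose2d`. The name `conv_transpose_out_len_sketch` is an
// assumption made for this example; it returns `None` when the extent would
// be zero or negative instead of letting the `usize` arithmetic wrap.
#[allow(dead_code)]
fn conv_transpose_out_len_sketch(
    in_len: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    output_padding: usize,
    dilation: usize,
) -> Option<usize> {
    if in_len == 0 || kernel == 0 {
        return None;
    }
    // Add every growing term first, then remove `2 * padding` with a check.
    let grown = (in_len - 1) * stride + dilation * (kernel - 1) + output_padding + 1;
    grown.checked_sub(2 * padding).filter(|&n| n > 0)
}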
2609
/// Insert `stride - 1` zeros between adjacent elements along both spatial axes.
///
/// Given input `[B, C, H, W]`, produces `[B, C, H_up, W_up]` where
/// `H_up = (H - 1) * stride_h + 1` and `W_up = (W - 1) * stride_w + 1`.
/// Requires `H >= 1` and `W >= 1`: the `usize` subtraction `H - 1` would
/// otherwise underflow.
2614fn stride_insert_zeros<T: Float>(
2615    input: &[T],
2616    batch: usize,
2617    channels: usize,
2618    h: usize,
2619    w: usize,
2620    stride_h: usize,
2621    stride_w: usize,
2622) -> (Vec<T>, usize, usize) {
2623    let h_up = (h - 1) * stride_h + 1;
2624    let w_up = (w - 1) * stride_w + 1;
2625    let zero = <T as num_traits::Zero>::zero();
2626    let mut out = vec![zero; batch * channels * h_up * w_up];
2627
2628    for b in 0..batch {
2629        for c in 0..channels {
2630            for ih in 0..h {
2631                for iw in 0..w {
2632                    let oh = ih * stride_h;
2633                    let ow = iw * stride_w;
2634                    out[b * channels * h_up * w_up + c * h_up * w_up + oh * w_up + ow] =
2635                        input[b * channels * h * w + c * h * w + ih * w + iw];
2636                }
2637            }
2638        }
2639    }
2640
2641    (out, h_up, w_up)
2642}
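
// Illustrative sketch, not used by the layer: a 1-D analogue of
// `stride_insert_zeros` above, showing the index mapping `i -> i * stride` on
// a single row. The name `insert_zeros_1d_sketch` and the `f32` element type
// are assumptions made for this example. Unlike the 2-D helper it returns an
// empty vector for an empty input or a zero stride rather than underflowing.
#[allow(dead_code)]
fn insert_zeros_1d_sketch(input: &[f32], stride: usize) -> Vec<f32> {
    if input.is_empty() || stride == 0 {
        return Vec::new();
    }
    // Upsampled length: (L - 1) * stride + 1.
    let mut out = vec![0.0f32; (input.len() - 1) * stride + 1];
    for (i, &x) in input.iter().enumerate() {
        out[i * stride] = x;
    }
    out
}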
2643
2644/// Crop a `[batch, channels, H, W]` plane by `crop_*` elements off BOTH ends of
2645/// each spatial axis (the 2-D analogue of `crop_volume_3d`). Used by the
2646/// transposed-conv forward when the internal padding `dilation*(k-1) - padding`
2647/// is negative; see `crop_volume_3d` for the upstream `col2vol`
2648/// (`aten/src/ATen/native/vol2col.h:80-106`) correspondence. Callers guarantee
2649/// `2*crop_* < extent`.
2650fn crop_plane_2d<T: Float>(
2651    input: &[T],
2652    batch: usize,
2653    channels: usize,
2654    h: usize,
2655    w: usize,
2656    crop_h: usize,
2657    crop_w: usize,
2658) -> (Vec<T>, usize, usize) {
2659    let h_out = h - 2 * crop_h;
2660    let w_out = w - 2 * crop_w;
2661    let zero = <T as num_traits::Zero>::zero();
2662    let mut out = vec![zero; batch * channels * h_out * w_out];
2663
2664    for b in 0..batch {
2665        for c in 0..channels {
2666            for oh in 0..h_out {
2667                let src = ((b * channels + c) * h + (oh + crop_h)) * w + crop_w;
2668                let dst = ((b * channels + c) * h_out + oh) * w_out;
2669                out[dst..dst + w_out].copy_from_slice(&input[src..src + w_out]);
2670            }
2671        }
2672    }
2673
2674    (out, h_out, w_out)
2675}
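
// Illustrative sketch, not used by the layer: how the signed internal padding
// `dilation*(k-1) - padding` splits into a non-negative zero-pad amount and a
// non-negative crop amount, at most one of which is non-zero. The name
// `split_internal_pad_sketch` is an assumption made for this example; it uses
// saturating `usize` arithmetic instead of the `isize` casts in the layer.
#[allow(dead_code)]
fn split_internal_pad_sketch(kernel: usize, dilation: usize, padding: usize) -> (usize, usize) {
    // `full` is the padding that would make the internal conv a "full" one.
    let full = dilation * kernel.saturating_sub(1);
    let pad = full.saturating_sub(padding); // zero-pad when padding < full
    let crop = padding.saturating_sub(full); // crop when padding > full
    (pad, crop)
}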
2676
2677/// Flip a kernel along both spatial axes: `kernel[c_in, c_out, kh, kw]` ->
2678/// `kernel[c_out, c_in, kH-1-kh, kW-1-kw]` (also transposes channel dims).
2679fn flip_kernel<T: Float>(kernel: &[T], c_in: usize, c_out: usize, kh: usize, kw: usize) -> Vec<T> {
2680    let zero = <T as num_traits::Zero>::zero();
2681    let mut flipped = vec![zero; c_out * c_in * kh * kw];
2682
2683    for ci in 0..c_in {
2684        for co in 0..c_out {
2685            for h in 0..kh {
2686                for w in 0..kw {
2687                    // Source: [c_in, c_out, h, w]
2688                    let src = ci * c_out * kh * kw + co * kh * kw + h * kw + w;
2689                    // Dest: [c_out, c_in, kH-1-h, kW-1-w]
2690                    let dst = co * c_in * kh * kw + ci * kh * kw + (kh - 1 - h) * kw + (kw - 1 - w);
2691                    flipped[dst] = kernel[src];
2692                }
2693            }
2694        }
2695    }
2696
2697    flipped
2698}
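
// Illustrative sketch, not used by the layer: a 1-D analogue of `flip_kernel`
// above, making the two effects easy to check by hand: the channel axes swap
// (`[c_in, c_out, k]` -> `[c_out, c_in, k]`) and the taps reverse. The name
// `flip_kernel_1d_sketch` and the `f32` element type are assumptions made for
// this example.
#[allow(dead_code)]
fn flip_kernel_1d_sketch(kernel: &[f32], c_in: usize, c_out: usize, k: usize) -> Vec<f32> {
    assert_eq!(kernel.len(), c_in * c_out * k);
    let mut flipped = vec![0.0f32; kernel.len()];
    for ci in 0..c_in {
        for co in 0..c_out {
            for t in 0..k {
                // Source: [c_in, c_out, t]
                let src = (ci * c_out + co) * k + t;
                // Dest: [c_out, c_in, k-1-t]
                let dst = (co * c_in + ci) * k + (k - 1 - t);
                flipped[dst] = kernel[src];
            }
        }
    }
    flipped
}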
2699
2700/// Single-group transposed 2-D convolution forward (the `groups == 1` core).
2701///
2702/// Operates on an already channel-sliced input slab `[B, in_pg, H, W]` and a
2703/// weight slab `[in_pg, out_pg, kH, kW]` (the transposed grouped layout,
2704/// `torch/nn/modules/conv.py:164`), returning the per-group output buffer
2705/// `[B, out_pg, h_out, w_out]`. Algorithm: stride-insert zeros, append the
2706/// `output_padding` boundary, flip+transpose the kernel, then run a stride-1
2707/// `dilation`-spaced internal convolution (`internal_pad = dilation*(k-1) -
2708/// padding`). This is the same math the dense ConvTranspose2d used (#1560),
2709/// now generalised for `dilation` via `im2col_dilated`.
2710// Internal kernel: the argument set mirrors the 2-D transposed-conv descriptor;
2711// a config struct would force allocation in the per-group hot loop.
2712#[allow(clippy::too_many_arguments)]
2713fn conv_transpose2d_forward_group<T: Float>(
2714    input_data: &[T],
2715    batch: usize,
2716    in_pg: usize,
2717    out_pg: usize,
2718    h: usize,
2719    w: usize,
2720    kernel_size: (usize, usize),
2721    stride: (usize, usize),
2722    padding: (usize, usize),
2723    output_padding: (usize, usize),
2724    dilation: (usize, usize),
2725    group_weight: &[T],
2726) -> FerrotorchResult<(Vec<T>, usize, usize)> {
2727    let (kh, kw) = kernel_size;
2728    let (sh, sw) = stride;
2729    let (ph, pw) = padding;
2730    let (oph, opw) = output_padding;
2731    let (dh, dw) = dilation;
2732
2733    // Step 1: stride-insert zeros, then append the `output_padding` boundary.
2734    let (upsampled, h_up_core, w_up_core) =
2735        stride_insert_zeros(input_data, batch, in_pg, h, w, sh, sw);
2736    let h_up = h_up_core + oph;
2737    let w_up = w_up_core + opw;
2738    let upsampled = if oph > 0 || opw > 0 {
2739        let zero = <T as num_traits::Zero>::zero();
2740        let mut ext = vec![zero; batch * in_pg * h_up * w_up];
2741        for b in 0..batch {
2742            for c in 0..in_pg {
2743                for ih in 0..h_up_core {
2744                    let src = ((b * in_pg + c) * h_up_core + ih) * w_up_core;
2745                    let dst = ((b * in_pg + c) * h_up + ih) * w_up;
2746                    ext[dst..dst + w_up_core].copy_from_slice(&upsampled[src..src + w_up_core]);
2747                }
2748            }
2749        }
2750        ext
2751    } else {
2752        upsampled
2753    };
2754
2755    // Step 2: flip the kernel and transpose channel dimensions.
2756    // Group weight: [in_pg, out_pg, kH, kW] -> flipped [out_pg, in_pg, kH, kW]
2757    // with a spatial flip (the regular-conv adjoint of the transposed conv).
2758    let flipped = flip_kernel(group_weight, in_pg, out_pg, kh, kw);
2759
    // Step 3: regular `dilation`-spaced stride-1 convolution on the upsampled
    // signal. The internal padding is `dilation*(kernel-1) - padding`, the
    // dilated generalisation of the dense `kernel-1-padding` (#1560). The `k`
    // kernel taps are spaced by `dilation` in `im2col_dilated`, so the kernel
    // spans an effective extent `eff_k = dilation*(k-1)+1`, mirroring
    // `ConvUtils.h:255`. When `padding > dilation*(k-1)` the internal pad is
    // NEGATIVE and the signal is CROPPED rather than zero-padded (casting the
    // negative value to `usize` would wrap and silently drop the scatter);
    // see `conv_transpose3d_forward_group` / `crop_volume_3d` for the upstream
    // `col2vol` (`aten/src/ATen/native/vol2col.h:80-106`) correspondence.
2770    let eff_kh = dh * (kh - 1) + 1;
2771    let eff_kw = dw * (kw - 1) + 1;
2772    let signed_pad_h = (eff_kh - 1) as isize - ph as isize;
2773    let signed_pad_w = (eff_kw - 1) as isize - pw as isize;
2774    let crop_h = (-signed_pad_h).max(0) as usize;
2775    let crop_w = (-signed_pad_w).max(0) as usize;
2776    let (conv_input, h_in, w_in) = if crop_h > 0 || crop_w > 0 {
2777        crop_plane_2d(&upsampled, batch, in_pg, h_up, w_up, crop_h, crop_w)
2778    } else {
2779        (upsampled, h_up, w_up)
2780    };
2781    let internal_pad_h = signed_pad_h.max(0) as usize;
2782    let internal_pad_w = signed_pad_w.max(0) as usize;
2783
2784    let (cols, col_rows, col_cols) = im2col_dilated(
2785        &conv_input,
2786        batch,
2787        in_pg,
2788        h_in,
2789        w_in,
2790        kh,
2791        kw,
2792        1,
2793        1,
2794        internal_pad_h,
2795        internal_pad_w,
2796        dh,
2797        dw,
2798    );
2799
2800    // Internal stride-1 conv output size over the `output_padding`-extended,
2801    // dilation-spaced signal.
2802    let h_out = (h_in + 2 * internal_pad_h - eff_kh) + 1;
2803    let w_out = (w_in + 2 * internal_pad_w - eff_kw) + 1;
2804
2805    // Reshape flipped kernel to 2-D: [out_pg, in_pg * kH * kW].
2806    let flipped_2d =
2807        Tensor::from_storage(TensorStorage::cpu(flipped), vec![out_pg, col_rows], false)?;
2808
2809    let zero = <T as num_traits::Zero>::zero();
2810    let mut output = vec![zero; batch * out_pg * h_out * w_out];
2811
2812    for b in 0..batch {
2813        let col_start = b * col_rows * col_cols;
2814        let col_end = col_start + col_rows * col_cols;
2815        let cols_b = Tensor::from_storage(
2816            TensorStorage::cpu(cols[col_start..col_end].to_vec()),
2817            vec![col_rows, col_cols],
2818            false,
2819        )?;
2820
2821        let out_b = mm(&flipped_2d, &cols_b)?;
2822        let out_data = out_b.data()?;
2823
2824        let out_start = b * out_pg * h_out * w_out;
2825        let copy_len = out_pg * h_out * w_out;
2826        output[out_start..out_start + copy_len].copy_from_slice(&out_data[..copy_len]);
2827    }
2828
2829    Ok((output, h_out, w_out))
2830}
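
// Illustrative sketch, not part of the layer. It walks one spatial axis
// through the same pad-or-crop arithmetic as `conv_transpose2d_forward_group`
// above. Assumptions, since the upsampling code sits outside this excerpt:
// the zero-inserted length is `(len - 1) * stride + 1 + output_padding`, a
// negative internal pad crops the same amount from both ends, and the
// arguments describe a valid layer (`len >= 1`, `kernel >= 1`, non-empty
// output). `sketch_transposed_len` is a hypothetical name.
#[allow(dead_code)]
fn sketch_transposed_len(
    len: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    output_padding: usize,
    dilation: usize,
) -> usize {
    let eff_k = dilation * (kernel - 1) + 1;
    // Assumed upsampled length: zeros between samples plus `output_padding`.
    let up = (len - 1) * stride + 1 + output_padding;
    // Signed internal pad: non-negative means zero-pad, negative means crop.
    let signed_pad = (eff_k - 1) as isize - padding as isize;
    let crop = (-signed_pad).max(0) as usize;
    let pad = signed_pad.max(0) as usize;
    let conv_in = up - 2 * crop;
    // Stride-1 convolution with the dilated kernel.
    conv_in + 2 * pad - eff_k + 1
}
// Under those assumptions both branches should reduce to the closed form
// `(len - 1) * stride - 2 * padding + dilation * (kernel - 1)
//  + output_padding + 1`.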

impl<T: Float> Module<T> for ConvTranspose2d<T> {
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        // Record autocast decision for conv_transpose2d.
        let _autocast_cat = autocast_guard("conv_transpose2d");

        // Unbatched input: `(C, H, W)` (rank 3) is accepted in addition to the
        // batched `(N, C, H, W)` (rank 4) form. Mirrors `batchify` at
        // `aten/src/ATen/native/Convolution.cpp:1197` (conv_transpose2d): an
        // unbatched input is `unsqueeze(0)`d to add a batch dim, transposed-
        // convolved, then `squeeze(0)`d so the output is rank 3. The
        // unsqueeze/squeeze are autograd-aware (`UnsqueezeBackward` /
        // `SqueezeBackward`) so gradients flow back to the unbatched shape.
        // Closes #1609.
        if input.ndim() == 3 {
            let batched = unsqueeze(input, 0)?;
            let output = self.forward(&batched)?;
            return squeeze(&output, 0);
        }

        // Validate input shape: [B, C_in, H, W].
        if input.ndim() != 4 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "ConvTranspose2d expects 4-D input [B, C, H, W], got {:?}",
                    input.shape()
                ),
            });
        }

        let batch = input.shape()[0];
        let c_in = input.shape()[1];
        let h = input.shape()[2];
        let w = input.shape()[3];

        if c_in != self.in_channels {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "ConvTranspose2d: expected {} input channels, got {}",
                    self.in_channels, c_in
                ),
            });
        }

        let (kh, kw) = self.kernel_size;
        let groups = self.groups;
        let in_pg = self.in_channels / groups;
        let out_pg = self.out_channels / groups;
        let weight_pg_numel = in_pg * out_pg * kh * kw;

        // Save the input device so we can restore it on the output.
        let input_device = input.device();

        let input_data = input.data_vec()?;
        let weight_data = self.weight.data_vec()?;

        // Per-group transposed convolution, mirroring the SlowTranspose grouped
        // loop at `aten/src/ATen/native/Convolution.cpp:1723-1729`: partition
        // the input on the channel axis (`in_pg` per group), the weight on dim
        // 0 (`in_pg` per group, giving the `[in_pg, out_pg, kH, kW]` slab), and
        // (later) the bias on dim 0 (`out_pg` per group); convolve-transpose
        // each group and concatenate the outputs on the channel axis.
        let zero = <T as num_traits::Zero>::zero();
        let mut output: Vec<T> = Vec::new();
        let mut h_out = 0usize;
        let mut w_out = 0usize;

        for g in 0..groups {
            // Slice this group's input channels: [B, in_pg, H, W].
            let mut group_input = vec![zero; batch * in_pg * h * w];
            for b in 0..batch {
                for c in 0..in_pg {
                    let src_c = g * in_pg + c;
                    let src_start = b * self.in_channels * h * w + src_c * h * w;
                    let dst_start = b * in_pg * h * w + c * h * w;
                    group_input[dst_start..dst_start + h * w]
                        .copy_from_slice(&input_data[src_start..src_start + h * w]);
                }
            }

            // Slice this group's weight slab: [in_pg, out_pg, kH, kW]. The
            // transposed weight is `[in_channels, out_pg, kH, kW]`, so group g
            // owns rows `[g*in_pg, (g+1)*in_pg)` of dim 0, which are contiguous
            // since dim 0 is the outermost axis.
            let w_start = g * weight_pg_numel;
            let group_weight = &weight_data[w_start..w_start + weight_pg_numel];

            let (g_out, gho, gwo) = conv_transpose2d_forward_group(
                &group_input,
                batch,
                in_pg,
                out_pg,
                h,
                w,
                self.kernel_size,
                self.stride,
                self.padding,
                self.output_padding,
                self.dilation,
                group_weight,
            )?;
            h_out = gho;
            w_out = gwo;

            if output.is_empty() {
                output = vec![zero; batch * self.out_channels * h_out * w_out];
            }
            // Place this group's `out_pg` channels at `[b, g*out_pg.., :, :]`.
            for b in 0..batch {
                for oc in 0..out_pg {
                    let dst_c = g * out_pg + oc;
                    let dst_start = b * self.out_channels * h_out * w_out + dst_c * h_out * w_out;
                    let src_start = (b * out_pg + oc) * h_out * w_out;
                    output[dst_start..dst_start + h_out * w_out]
                        .copy_from_slice(&g_out[src_start..src_start + h_out * w_out]);
                }
            }
        }

        // Add bias if present.
        if let Some(ref bias) = self.bias {
            let bias_data = bias.data_vec()?;
            for b in 0..batch {
                for c in 0..self.out_channels {
                    let bval = bias_data[c];
                    for hw in 0..(h_out * w_out) {
                        output[b * self.out_channels * h_out * w_out + c * h_out * w_out + hw] +=
                            bval;
                    }
                }
            }
        }

        let result = Tensor::from_storage(
            TensorStorage::cpu(output),
            vec![batch, self.out_channels, h_out, w_out],
            false,
        )?;

        // Attach backward if gradients are enabled.
        if is_grad_enabled()
            && (input.requires_grad()
                || self.weight.requires_grad()
                || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
        {
            let grad_fn = Arc::new(ConvTranspose2dBackward {
                input: input.clone(),
                weight: self.weight.tensor().clone(),
                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
                in_channels: self.in_channels,
                out_channels: self.out_channels,
                kernel_size: self.kernel_size,
                stride: self.stride,
                padding: self.padding,
                _output_padding: self.output_padding,
                dilation: self.dilation,
                groups: self.groups,
                h_out,
                w_out,
            });
            Tensor::from_operation(
                TensorStorage::cpu(result.data()?.to_vec()),
                result.shape().to_vec(),
                grad_fn,
            )?
            .to(input_device) // restore device
        } else {
            result.to(input_device)
        }
    }

    fn parameters(&self) -> Vec<&Parameter<T>> {
        let mut params = vec![&self.weight];
        if let Some(ref b) = self.bias {
            params.push(b);
        }
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        let mut params = vec![&mut self.weight];
        if let Some(ref mut b) = self.bias {
            params.push(b);
        }
        params
    }

    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        let mut params = vec![("weight".to_string(), &self.weight)];
        if let Some(ref b) = self.bias {
            params.push(("bias".to_string(), b));
        }
        params
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }
}
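
// Illustrative sketch, not part of the layer: the per-group channel slice
// from `forward` above, reduced to a free function over a flat NCHW buffer.
// `sketch_slice_group` and its parameter names are hypothetical. It assumes
// `groups` divides `channels` and `g < groups`. Within one batch entry a
// group's channels are adjacent, so each batch entry needs only one copy.
#[allow(dead_code)]
fn sketch_slice_group(
    data: &[f32],
    batch: usize,
    channels: usize,
    plane: usize, // H * W elements per channel
    groups: usize,
    g: usize,
) -> Vec<f32> {
    let per_group = channels / groups;
    let mut out = Vec::with_capacity(batch * per_group * plane);
    for b in 0..batch {
        // First element of channel `g * per_group` in batch entry `b`.
        let start = (b * channels + g * per_group) * plane;
        out.extend_from_slice(&data[start..start + per_group * plane]);
    }
    out
}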

// ---------------------------------------------------------------------------
// ConvTranspose2dBackward
// ---------------------------------------------------------------------------

/// Backward function for the `ConvTranspose2d` forward pass.
///
/// The gradient with respect to the input is a regular (forward) convolution
/// of `grad_output` with the unflipped weight; the weight gradient is a
/// `dilation`-spaced cross-correlation of the input with `grad_output`; the
/// bias gradient sums `grad_output` over the batch and spatial axes.
#[derive(Debug)]
struct ConvTranspose2dBackward<T: Float> {
    input: Tensor<T>,
    weight: Tensor<T>,
    bias: Option<Tensor<T>>,
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    _output_padding: (usize, usize),
    dilation: (usize, usize),
    groups: usize,
    h_out: usize,
    w_out: usize,
}

impl<T: Float> GradFn<T> for ConvTranspose2dBackward<T> {
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        // grad_output shape: [B, C_out, H_out, W_out]
        let go_data = grad_output.data_vec()?;
        let batch = self.input.shape()[0];
        let h_in = self.input.shape()[2];
        let w_in = self.input.shape()[3];
        let (kh, kw) = self.kernel_size;
        let (sh, sw) = self.stride;
        let (ph, pw) = self.padding;
        let (dh, dw) = self.dilation;
        let groups = self.groups;
        let in_pg = self.in_channels / groups;
        let out_pg = self.out_channels / groups;
        let zero = <T as num_traits::Zero>::zero();

        let weight_data_all = self.weight.data_vec()?;
        let input_data_all = self.input.data_vec()?;

        // The grad_input / grad_weight of a transposed convolution decompose
        // per-group exactly like the forward (Convolution.cpp:1723-1729 +
        // 2282-2297 grouped backward): for group g the relevant channels are
        // in `[g*in_pg, (g+1)*in_pg)` (input) and `[g*out_pg, (g+1)*out_pg)`
        // (grad_output), using the weight slab `[in_pg, out_pg, kH, kW]`.
        let mut gi_all = if self.input.requires_grad() {
            Some(vec![zero; batch * self.in_channels * h_in * w_in])
        } else {
            None
        };
        let mut gw_all = if self.weight.requires_grad() {
            Some(vec![zero; self.in_channels * out_pg * kh * kw])
        } else {
            None
        };

        for g in 0..groups {
            // --- grad_input (group g) ---
            // The backward of a transposed conv w.r.t. input is a regular
            // (forward) convolution of grad_output with the original
            // (non-flipped) weight, dilation-spaced. Weight slab is
            // [in_pg, out_pg, kH, kW] reshaped to [in_pg, out_pg*kH*kW].
            if let Some(gi) = gi_all.as_mut() {
                let col_rows = out_pg * kh * kw;
                let w_start = g * in_pg * out_pg * kh * kw;
                let weight_2d = Tensor::from_storage(
                    TensorStorage::cpu(
                        weight_data_all[w_start..w_start + in_pg * out_pg * kh * kw].to_vec(),
                    ),
                    vec![in_pg, col_rows],
                    false,
                )?;

                // Slice this group's grad_output channels: [B, out_pg, H_out, W_out].
                let mut go_g = vec![zero; batch * out_pg * self.h_out * self.w_out];
                for b in 0..batch {
                    for c in 0..out_pg {
                        let src_c = g * out_pg + c;
                        let src = (b * self.out_channels + src_c) * self.h_out * self.w_out;
                        let dst = (b * out_pg + c) * self.h_out * self.w_out;
                        go_g[dst..dst + self.h_out * self.w_out]
                            .copy_from_slice(&go_data[src..src + self.h_out * self.w_out]);
                    }
                }

                let (go_cols, _gcr, go_col_cols) = im2col_dilated(
                    &go_g, batch, out_pg, self.h_out, self.w_out, kh, kw, sh, sw, ph, pw, dh, dw,
                );
                debug_assert_eq!(go_col_cols, h_in * w_in);

                for b in 0..batch {
                    let col_start = b * col_rows * go_col_cols;
                    let col_end = col_start + col_rows * go_col_cols;
                    let go_cols_b = Tensor::from_storage(
                        TensorStorage::cpu(go_cols[col_start..col_end].to_vec()),
                        vec![col_rows, go_col_cols],
                        false,
                    )?;
                    let gi_b = mm(&weight_2d, &go_cols_b)?;
                    let gi_data = gi_b.data()?;
                    // Scatter the group's in_pg channels back into the dense input grad.
                    for c in 0..in_pg {
                        let dst_c = g * in_pg + c;
                        let dst = (b * self.in_channels + dst_c) * h_in * w_in;
                        let src = c * h_in * w_in;
                        gi[dst..dst + h_in * w_in]
                            .copy_from_slice(&gi_data[src..src + h_in * w_in]);
                    }
                }
            }

            // --- grad_weight (group g) ---
            // grad_weight[ci, co, tkh, tkw] = sum over b, ih, iw of
            //   input[b, ci, ih, iw]
            //     * grad_output[b, co, ih*sh + tkh*dh - ph, iw*sw + tkw*dw - pw],
            // skipping taps whose grad_output position falls out of range.
            if let Some(gw) = gw_all.as_mut() {
                for ci in 0..in_pg {
                    let in_c = g * in_pg + ci;
                    for co in 0..out_pg {
                        let out_c = g * out_pg + co;
                        for tkh in 0..kh {
                            for tkw in 0..kw {
                                let mut acc = zero;
                                for ih in 0..h_in {
                                    for iw in 0..w_in {
                                        let oh = ih * sh + tkh * dh;
                                        let ow = iw * sw + tkw * dw;
                                        if oh >= ph
                                            && ow >= pw
                                            && (oh - ph) < self.h_out
                                            && (ow - pw) < self.w_out
                                        {
                                            let go_index = (out_c * self.h_out + (oh - ph))
                                                * self.w_out
                                                + (ow - pw);
                                            let in_index = (in_c * h_in + ih) * w_in + iw;
                                            // Sum over the batch.
                                            for b in 0..batch {
                                                let goi =
                                                    b * self.out_channels * self.h_out * self.w_out
                                                        + go_index;
                                                let ini =
                                                    b * self.in_channels * h_in * w_in + in_index;
                                                acc += input_data_all[ini] * go_data[goi];
                                            }
                                        }
                                    }
                                }
                                // gw layout: [in_channels, out_pg, kH, kW].
                                gw[((in_c * out_pg + co) * kh + tkh) * kw + tkw] += acc;
                            }
                        }
                    }
                }
            }
        }

        let grad_input = match gi_all {
            Some(gi) => Some(Tensor::from_storage(
                TensorStorage::cpu(gi),
                self.input.shape().to_vec(),
                false,
            )?),
            None => None,
        };
        let grad_weight = match gw_all {
            Some(gw) => Some(Tensor::from_storage(
                TensorStorage::cpu(gw),
                vec![self.in_channels, out_pg, kh, kw],
                false,
            )?),
            None => None,
        };

        // --- grad_bias ---
        let grad_bias = match &self.bias {
            Some(b) if b.requires_grad() => {
                let zero = <T as num_traits::Zero>::zero();
                let mut gb = vec![zero; self.out_channels];
                for batch_idx in 0..batch {
                    for c in 0..self.out_channels {
                        for hw in 0..(self.h_out * self.w_out) {
                            gb[c] +=
                                go_data[batch_idx * self.out_channels * self.h_out * self.w_out
                                    + c * self.h_out * self.w_out
                                    + hw];
                        }
                    }
                }
                Some(Tensor::from_storage(
                    TensorStorage::cpu(gb),
                    vec![self.out_channels],
                    false,
                )?)
            }
            _ => None,
        };

        let mut grads = vec![grad_input, grad_weight];
        if self.bias.is_some() {
            grads.push(grad_bias);
        }
        Ok(grads)
    }

    fn inputs(&self) -> Vec<&Tensor<T>> {
        let mut v = vec![&self.input, &self.weight];
        if let Some(ref b) = self.bias {
            v.push(b);
        }
        v
    }

    fn name(&self) -> &'static str {
        "ConvTranspose2dBackward"
    }
}
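
// Illustrative sketch, not part of the layer: a single-channel 1-D analogue
// of the grad_weight loop above, paired with a scatter-style transposed
// convolution so the two can be checked against each other. Both function
// names are hypothetical, and `out_len` is assumed to be supplied by the
// caller. Because the forward is linear in the weight, the loss
// `sum(out * grad_out)` evaluated with a one-hot weight at tap `k` should
// equal entry `k` of the weight gradient.
#[allow(dead_code)]
fn sketch_conv_transpose1d(
    x: &[f64],
    w: &[f64],
    stride: usize,
    padding: usize,
    dilation: usize,
    out_len: usize,
) -> Vec<f64> {
    let mut out = vec![0.0; out_len];
    for (i, &xi) in x.iter().enumerate() {
        for (k, &wk) in w.iter().enumerate() {
            // Input position i and tap k land on output position o - padding.
            let o = i * stride + k * dilation;
            if o >= padding && o - padding < out_len {
                out[o - padding] += xi * wk;
            }
        }
    }
    out
}

#[allow(dead_code)]
fn sketch_grad_weight1d(
    x: &[f64],
    grad_out: &[f64],
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> Vec<f64> {
    let mut gw = vec![0.0; kernel];
    for (k, gw_k) in gw.iter_mut().enumerate() {
        for (i, &xi) in x.iter().enumerate() {
            // Same index relation as the forward scatter, read instead of written.
            let o = i * stride + k * dilation;
            if o >= padding && o - padding < grad_out.len() {
                *gw_k += xi * grad_out[o - padding];
            }
        }
    }
    gw
}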

// ---------------------------------------------------------------------------
// im2col_3d / col2im_3d helpers
// ---------------------------------------------------------------------------

/// Extract volumetric patches into columns, supporting dilation
/// `(dil_d, dil_h, dil_w)`.
///
/// Given a 5-D input `[B, C, D, H, W]`, produces
/// `[B, C * kD * kH * kW, D_out * H_out * W_out]` where each column is one
/// flattened receptive-field patch with kernel taps spaced by the dilation
/// factors. Output spatial sizes follow `out = (in + 2*pad - dil*(k - 1) -
/// 1)/stride + 1`, mirroring `ConvUtils.h:255-256`.
// Internal kernel: argument set mirrors the 3-D convolution descriptor; a
// config struct would force allocation on every call in convolution hot paths.
#[allow(clippy::too_many_arguments)]
fn im2col_3d_dilated<T: Float>(
    input: &[T],
    batch: usize,
    channels: usize,
    depth: usize,
    height: usize,
    width: usize,
    kernel_d: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_d: usize,
    stride_h: usize,
    stride_w: usize,
    pad_d: usize,
    pad_h: usize,
    pad_w: usize,
    dil_d: usize,
    dil_h: usize,
    dil_w: usize,
) -> (Vec<T>, usize, usize) {
    let eff_kd = dil_d * (kernel_d - 1) + 1;
    let eff_kh = dil_h * (kernel_h - 1) + 1;
    let eff_kw = dil_w * (kernel_w - 1) + 1;
    let d_out = (depth + 2 * pad_d - eff_kd) / stride_d + 1;
    let h_out = (height + 2 * pad_h - eff_kh) / stride_h + 1;
    let w_out = (width + 2 * pad_w - eff_kw) / stride_w + 1;
    let col_rows = channels * kernel_d * kernel_h * kernel_w;
    let col_cols = d_out * h_out * w_out;

    let zero = <T as num_traits::Zero>::zero();
    let mut cols = vec![zero; batch * col_rows * col_cols];

    for b in 0..batch {
        for c in 0..channels {
            for kd in 0..kernel_d {
                for kh in 0..kernel_h {
                    for kw in 0..kernel_w {
                        let row = c * kernel_d * kernel_h * kernel_w
                            + kd * kernel_h * kernel_w
                            + kh * kernel_w
                            + kw;
                        for od in 0..d_out {
                            for oh in 0..h_out {
                                for ow in 0..w_out {
                                    let id = od * stride_d + kd * dil_d;
                                    let ih = oh * stride_h + kh * dil_h;
                                    let iw = ow * stride_w + kw * dil_w;
                                    let col = od * h_out * w_out + oh * w_out + ow;

                                    let val = if id >= pad_d
                                        && ih >= pad_h
                                        && iw >= pad_w
                                        && (id - pad_d) < depth
                                        && (ih - pad_h) < height
                                        && (iw - pad_w) < width
                                    {
                                        let real_d = id - pad_d;
                                        let real_h = ih - pad_h;
                                        let real_w = iw - pad_w;
                                        input[b * channels * depth * height * width
                                            + c * depth * height * width
                                            + real_d * height * width
                                            + real_h * width
                                            + real_w]
                                    } else {
                                        zero
                                    };

                                    cols[b * col_rows * col_cols + row * col_cols + col] = val;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    (cols, col_rows, col_cols)
}
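
// Illustrative sketch, not part of the layer: the same gather as
// `im2col_3d_dilated`, collapsed to one channel and one spatial axis so the
// dilation spacing and zero padding are easy to follow by hand.
// `sketch_im2col_1d` is a hypothetical name. It assumes `kernel >= 1`,
// `stride >= 1`, and a padded input at least as long as the effective kernel.
// Returns the row-major `[kernel, out_len]` matrix together with `out_len`.
#[allow(dead_code)]
fn sketch_im2col_1d(
    input: &[f32],
    kernel: usize,
    stride: usize,
    pad: usize,
    dil: usize,
) -> (Vec<f32>, usize) {
    let eff_k = dil * (kernel - 1) + 1;
    let out_len = (input.len() + 2 * pad - eff_k) / stride + 1;
    let mut cols = vec![0.0; kernel * out_len];
    for k in 0..kernel {
        for o in 0..out_len {
            // Position in the padded signal; taps are `dil` apart.
            let i = o * stride + k * dil;
            if i >= pad && i - pad < input.len() {
                cols[k * out_len + o] = input[i - pad];
            }
            // Otherwise the tap falls in the padding and stays zero.
        }
    }
    (cols, out_len)
}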

/// Scatter columns back into a volume tensor with dilation support
/// (adjoint of `im2col_3d_dilated`). The non-dilated 3-D scatter is simply
/// this with `(dil_d, dil_h, dil_w) = (1, 1, 1)`; production callers
/// (`Conv3dBackward`) always pass the layer's dilation directly, so no
/// separate non-dilated shim is kept.
// Internal kernel: adjoint of `im2col_3d_dilated`; same descriptor signature.
#[allow(clippy::too_many_arguments)]
fn col2im_3d_dilated<T: Float>(
    cols: &[T],
    batch: usize,
    channels: usize,
    depth: usize,
    height: usize,
    width: usize,
    kernel_d: usize,
    kernel_h: usize,
    kernel_w: usize,
    stride_d: usize,
    stride_h: usize,
    stride_w: usize,
    pad_d: usize,
    pad_h: usize,
    pad_w: usize,
    dil_d: usize,
    dil_h: usize,
    dil_w: usize,
    d_out: usize,
    h_out: usize,
    w_out: usize,
) -> Vec<T> {
    let zero = <T as num_traits::Zero>::zero();
    let mut output = vec![zero; batch * channels * depth * height * width];

    let col_rows = channels * kernel_d * kernel_h * kernel_w;
    let col_cols = d_out * h_out * w_out;

    for b in 0..batch {
        for c in 0..channels {
            for kd in 0..kernel_d {
                for kh in 0..kernel_h {
                    for kw in 0..kernel_w {
                        let row = c * kernel_d * kernel_h * kernel_w
                            + kd * kernel_h * kernel_w
                            + kh * kernel_w
                            + kw;
                        for od in 0..d_out {
                            for oh in 0..h_out {
                                for ow in 0..w_out {
                                    let id = od * stride_d + kd * dil_d;
                                    let ih = oh * stride_h + kh * dil_h;
                                    let iw = ow * stride_w + kw * dil_w;
                                    let col = od * h_out * w_out + oh * w_out + ow;

                                    if id >= pad_d
                                        && ih >= pad_h
                                        && iw >= pad_w
                                        && (id - pad_d) < depth
                                        && (ih - pad_h) < height
                                        && (iw - pad_w) < width
                                    {
                                        let real_d = id - pad_d;
                                        let real_h = ih - pad_h;
                                        let real_w = iw - pad_w;
                                        output[b * channels * depth * height * width
                                            + c * depth * height * width
                                            + real_d * height * width
                                            + real_h * width
                                            + real_w] +=
                                            cols[b * col_rows * col_cols + row * col_cols + col];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    output
}
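
// Illustrative sketch, not part of the layer: the scatter in
// `col2im_3d_dilated`, collapsed to one channel and one spatial axis.
// `sketch_col2im_1d` is a hypothetical name, and `cols` is assumed to be a
// row-major `[kernel, out_len]` matrix. Positions covered by more than one
// tap accumulate, which is what makes the scatter the adjoint of the gather.
#[allow(dead_code)]
fn sketch_col2im_1d(
    cols: &[f32],
    len: usize,
    kernel: usize,
    stride: usize,
    pad: usize,
    dil: usize,
    out_len: usize,
) -> Vec<f32> {
    let mut output = vec![0.0; len];
    for k in 0..kernel {
        for o in 0..out_len {
            // Same index relation as the gather; entries in the padding are dropped.
            let i = o * stride + k * dil;
            if i >= pad && i - pad < len {
                output[i - pad] += cols[k * out_len + o];
            }
        }
    }
    output
}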

// ---------------------------------------------------------------------------
// Conv3d
// ---------------------------------------------------------------------------

/// A 3-D convolution layer for volumetric data.
///
/// Applies a volumetric convolution over an input `[B, C_in, D, H, W]` using
/// the im2col + matmul algorithm. Equivalent to `torch.nn.Conv3d`.
///
/// # Shape
///
/// - Input: `[B, in_channels, D, H, W]`
/// - Output: `[B, out_channels, D_out, H_out, W_out]`
///
/// where `D_out = (D + 2 * padding.0 - dilation.0 * (kernel_size.0 - 1) - 1)
/// / stride.0 + 1` (and analogously for H and W).
#[derive(Debug)]
pub struct Conv3d<T: Float> {
    /// Learnable kernel weights `[out_channels, in_channels / groups, kD, kH, kW]`.
    weight: Parameter<T>,
    /// Optional learnable bias `[out_channels]`.
    bias: Option<Parameter<T>>,
    /// Number of input channels.
    in_channels: usize,
    /// Number of output channels (filters).
    out_channels: usize,
    /// Kernel spatial size `(kD, kH, kW)`.
    kernel_size: (usize, usize, usize),
    /// Stride `(sD, sH, sW)`.
    stride: (usize, usize, usize),
    /// Zero-padding `(pD, pH, pW)` applied to both sides.
    padding: (usize, usize, usize),
    /// Dilation `(dilD, dilH, dilW)`. `(1, 1, 1)` is the dense default.
    /// Spaces kernel taps apart along each spatial axis (`eff_kernel =
    /// dilation * (k - 1) + 1`), mirroring `torch.nn.Conv3d(..., dilation=1)`
    /// (`conv.py:689`).
    dilation: (usize, usize, usize),
    /// Number of blocked connections from input channels to output
    /// channels. `1` is dense, `in_channels` is depthwise. Must divide both
    /// `in_channels` and `out_channels`, mirroring
    /// `torch.nn.Conv3d(..., groups=1)` (`conv.py:690`, validation
    /// `conv.py:107-110`).
    groups: usize,
    /// Boundary handling for the spatial padding. `Zeros` (default) routes
    /// through the existing im2col zero-pad path; non-`Zeros` modes pre-pad
    /// the input via `crate::padding::functional_pad_3d` and then run the
    /// dense im2col over the already-padded tensor (matching the upstream
    /// `Conv3d._conv_forward`: `F.pad(input, ..., mode=...)` first, then a
    /// `padding=0` convolution). See `torch/nn/modules/conv.py:716-732`.
    /// Closes #1443.
    padding_mode: crate::padding::PaddingMode,
    /// String padding mode (`'same'` / `'valid'`), `None` when numeric
    /// `padding` is used. When `Some`, the `padding` field is ignored and the
    /// effective padding is derived per [`StringPadding`] in `forward`
    /// (mirroring the `padding: str` branch of `torch.nn.Conv3d`,
    /// `torch/nn/modules/conv.py:111-155`). Set via
    /// [`Conv3d::with_string_padding`]. Closes #1602.
    string_padding: Option<StringPadding>,
    /// Whether the module is in training mode.
    training: bool,
}
3497
3498impl<T: Float> Conv3d<T> {
3499    /// Create a new `Conv3d` layer (dense, dilation `(1, 1, 1)`, `groups = 1`).
3500    ///
3501    /// Weight is initialized with Kaiming uniform (ReLU gain).
3502    /// Bias, if enabled, is initialized U(-bound, bound) with
3503    /// `bound = 1/sqrt(fan_in)` per `torch/nn/modules/conv.py:198-201`.
3504    ///
3505    /// This is a thin shim over [`Conv3d::new_full`] preserved for callers
3506    /// that only need the dense configuration (e.g. `LazyConv3d::materialize`).
3507    pub fn new(
3508        in_channels: usize,
3509        out_channels: usize,
3510        kernel_size: (usize, usize, usize),
3511        stride: (usize, usize, usize),
3512        padding: (usize, usize, usize),
3513        bias: bool,
3514    ) -> FerrotorchResult<Self> {
3515        Self::new_full(
3516            in_channels,
3517            out_channels,
3518            kernel_size,
3519            stride,
3520            padding,
3521            (1, 1, 1),
3522            1,
3523            bias,
3524        )
3525    }
3526
3527    /// Create a new `Conv3d` layer with the full PyTorch-shaped argument set,
3528    /// including `dilation` and `groups`.
3529    ///
3530    /// `groups` must divide BOTH `in_channels` and `out_channels` (PyTorch
3531    /// `torch.nn.Conv3d` raises `ValueError` otherwise, `conv.py:107-110`).
3532    /// `dilation` must be strictly positive in all dimensions. Weight is
3533    /// initialized with Kaiming uniform (ReLU gain); bias (if enabled) with
3534    /// U(-bound, bound) where `bound = 1/sqrt(fan_in)`, `fan_in =
3535    /// (in_channels/groups) * kD * kH * kW` per
3536    /// `torch/nn/modules/conv.py:198-201`.
3537    ///
3538    /// Weight layout is `[out_channels, in_channels / groups, kD, kH, kW]`,
3539    /// the PyTorch grouped-conv convention (`conv.py:171`). Argument order
3540    /// `(.., dilation, groups, bias)` mirrors `Conv3d.__init__`
3541    /// (`conv.py:682-691`, R-DEV-2).
3542    #[allow(clippy::too_many_arguments)]
3543    pub fn new_full(
3544        in_channels: usize,
3545        out_channels: usize,
3546        kernel_size: (usize, usize, usize),
3547        stride: (usize, usize, usize),
3548        padding: (usize, usize, usize),
3549        dilation: (usize, usize, usize),
3550        groups: usize,
3551        bias: bool,
3552    ) -> FerrotorchResult<Self> {
3553        if in_channels == 0 || out_channels == 0 {
3554            return Err(FerrotorchError::InvalidArgument {
3555                message: "in_channels and out_channels must be > 0".into(),
3556            });
3557        }
3558        if kernel_size.0 == 0 || kernel_size.1 == 0 || kernel_size.2 == 0 {
3559            return Err(FerrotorchError::InvalidArgument {
3560                message: "kernel_size must be > 0 in all dimensions".into(),
3561            });
3562        }
3563        if stride.0 == 0 || stride.1 == 0 || stride.2 == 0 {
3564            return Err(FerrotorchError::InvalidArgument {
3565                message: "stride must be > 0 in all dimensions".into(),
3566            });
3567        }
3568        if dilation.0 == 0 || dilation.1 == 0 || dilation.2 == 0 {
3569            return Err(FerrotorchError::InvalidArgument {
3570                message: format!(
3571                    "Conv3d::new_full: dilation must be > 0 in all dimensions, got {dilation:?}"
3572                ),
3573            });
3574        }
3575        if groups == 0 {
3576            return Err(FerrotorchError::InvalidArgument {
3577                message: "Conv3d::new_full: groups must be > 0".into(),
3578            });
3579        }
3580        // `torch/nn/modules/conv.py:107-110`: `in_channels % groups != 0`
3581        // and `out_channels % groups != 0` each raise ValueError.
3582        if in_channels % groups != 0 {
3583            return Err(FerrotorchError::InvalidArgument {
3584                message: format!(
3585                    "Conv3d::new_full: groups={groups} must divide in_channels={in_channels}"
3586                ),
3587            });
3588        }
3589        if out_channels % groups != 0 {
3590            return Err(FerrotorchError::InvalidArgument {
3591                message: format!(
3592                    "Conv3d::new_full: groups={groups} must divide out_channels={out_channels}"
3593                ),
3594            });
3595        }
3596
3597        let (kd, kh, kw) = kernel_size;
3598        // PyTorch weight layout is [C_out, C_in / groups, kD, kH, kW] (`conv.py:171`).
3599        let mut weight = Parameter::zeros(&[out_channels, in_channels / groups, kd, kh, kw])?;
3600        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
3601
3602        let bias_param = if bias {
3603            let mut b = Parameter::zeros(&[out_channels])?;
3604            // `torch/nn/modules/conv.py:198-201`: bias U(-bound, bound) with
3605            //   `bound = 1 / sqrt(fan_in)`, `fan_in = (in_channels/groups) * kD * kH * kW`.
3606            let fan_in = (in_channels / groups) * kd * kh * kw;
3607            let bound = if fan_in > 0 {
3608                1.0 / (fan_in as f64).sqrt()
3609            } else {
3610                0.0
3611            };
3612            uniform_init(&mut b, -bound, bound)?;
3613            Some(b)
3614        } else {
3615            None
3616        };
3617
3618        Ok(Self {
3619            weight,
3620            bias: bias_param,
3621            in_channels,
3622            out_channels,
3623            kernel_size,
3624            stride,
3625            padding,
3626            dilation,
3627            groups,
3628            padding_mode: crate::padding::PaddingMode::Zeros,
3629            string_padding: None,
3630            training: true,
3631        })
3632    }
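
// Illustrative aside: a minimal standalone sketch of the geometry that
// `new_full` derives, assuming only the rules stated in its doc comment
// (`groups` divides both channel counts, weight layout
// `[out_channels, in_channels / groups, kD, kH, kW]`, bias bound
// `1 / sqrt(fan_in)`). The name `conv3d_init_geometry` and the `String`
// error type are hypothetical and are not part of ferrotorch.

```rust
/// Hypothetical helper: validates the channel/group arguments and returns
/// the grouped weight shape together with the bias bound `1 / sqrt(fan_in)`.
fn conv3d_init_geometry(
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize, usize),
    groups: usize,
) -> Result<([usize; 5], f64), String> {
    let (kd, kh, kw) = kernel_size;
    if in_channels == 0 || out_channels == 0 {
        return Err("in_channels and out_channels must be > 0".to_string());
    }
    if kd == 0 || kh == 0 || kw == 0 {
        return Err("kernel_size must be > 0 in all dimensions".to_string());
    }
    if groups == 0 {
        return Err("groups must be > 0".to_string());
    }
    if in_channels % groups != 0 {
        return Err(format!("groups={groups} must divide in_channels={in_channels}"));
    }
    if out_channels % groups != 0 {
        return Err(format!("groups={groups} must divide out_channels={out_channels}"));
    }

    // Each output channel only sees the input channels of its own group.
    let in_per_group = in_channels / groups;
    let fan_in = in_per_group * kd * kh * kw;
    let bound = 1.0 / (fan_in as f64).sqrt();
    Ok(([out_channels, in_per_group, kd, kh, kw], bound))
}
```

// With `in_channels = 8`, `out_channels = 16`, a `(2, 2, 2)` kernel and
// `groups = 4`, the sketch yields shape `[16, 2, 2, 2, 2]` and a bias bound
// of `1 / sqrt(16)`.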
3633
3634    /// Number of channel groups (`1` is dense, `in_channels` is depthwise).
3635    pub fn groups(&self) -> usize {
3636        self.groups
3637    }
3638
3639    /// Dilation `(dilD, dilH, dilW)` (`(1, 1, 1)` is the dense default).
3640    pub fn dilation(&self) -> (usize, usize, usize) {
3641        self.dilation
3642    }
3643
3644    /// Configure string padding (`'same'` / `'valid'`), mirroring the
3645    /// `padding: str` branch of `torch.nn.Conv3d` (`conv.py:111-155`).
3646    ///
3647    /// `StringPadding::Valid` is equivalent to `padding = 0`.
3648    /// `StringPadding::Same` pads so the output spatial size equals the input
3649    /// spatial size (for `stride = 1`), splitting each per-dim total
3650    /// `dilation * (kernel - 1)` asymmetrically as `left = total/2`,
3651    /// `right = total - left` (the END gets the extra unit; see
3652    /// [`same_pad_lr`]). The pre-pad uses the configured `padding_mode`
3653    /// (constant-0 for the default `Zeros`, matching `convolution_same`'s
3654    /// `constant_pad_nd(.., 0)`, `Convolution.cpp:1105`) and is autograd-aware
3655    /// via `Pad3dBackward`.
3656    ///
3657    /// Returns `Err` if `StringPadding::Same` is requested with a stride other
3658    /// than 1 in any dimension, matching upstream
3659    /// `raise ValueError("padding='same' is not supported for strided
3660    /// convolutions")` (`conv.py:117-120`, `Convolution.cpp:1071`). Closes
3661    /// #1602.
3662    pub fn with_string_padding(mut self, padding: StringPadding) -> FerrotorchResult<Self> {
3663        if padding == StringPadding::Same
3664            && (self.stride.0 != 1 || self.stride.1 != 1 || self.stride.2 != 1)
3665        {
3666            return Err(FerrotorchError::InvalidArgument {
3667                message: "padding='same' is not supported for strided convolutions".into(),
3668            });
3669        }
3670        self.string_padding = Some(padding);
3671        self.padding = (0, 0, 0);
3672        Ok(self)
3673    }
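
// Illustrative aside: a small sketch of the `'same'` padding split described
// in the doc comment above (per-dim total `dilation * (kernel - 1)`,
// `left = total / 2`, `right = total - left`). `same_pad_split` is a
// hypothetical stand-in written for this example; the body of the real
// `same_pad_lr` is not shown in this section and may differ.

```rust
/// Hypothetical stand-in for the split rule quoted in the doc comment:
/// the END of each dimension receives the extra unit when `total` is odd.
fn same_pad_split(kernel: usize, dilation: usize) -> (usize, usize) {
    // Total padding needed along one dimension when stride is 1.
    let total = dilation * kernel.saturating_sub(1);
    let left = total / 2;
    (left, total - left)
}

/// Output length of a stride-1 convolution after the explicit pre-pad.
/// It should equal `input`, which is the point of `'same'` padding.
fn same_out_len(input: usize, kernel: usize, dilation: usize) -> usize {
    let (left, right) = same_pad_split(kernel, dilation);
    let eff_kernel = dilation * kernel.saturating_sub(1) + 1;
    input + left + right + 1 - eff_kernel
}
```

// An even kernel shows the asymmetry: `kernel = 4`, `dilation = 1` gives a
// total of 3, split as `(1, 2)`.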
3674
3675    /// Configure the boundary handling for the spatial padding.
3676    ///
3677    /// `Zeros` (default) uses the existing im2col zero-pad path.
3678    /// `Reflect`, `Replicate`, and `Circular` pre-pad the input via
3679    /// `crate::padding::functional_pad_3d(input, ...)` and then convolve
3680    /// with `padding = 0`, matching `torch.nn.Conv3d(..., padding_mode=...)`
3681    /// (`Conv3d._conv_forward`'s `F.pad` shape, `conv.py:716-732`). The pad
3682    /// is autograd-aware (`Pad3dBackward`), so input gradients flow through
3683    /// the boundary. Closes #1443.
3684    pub fn with_padding_mode(mut self, mode: crate::padding::PaddingMode) -> Self {
3685        self.padding_mode = mode;
3686        self
3687    }
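
// Illustrative aside: non-zero padding modes pre-pad the input and then
// convolve with `padding = 0`. The sketch below only shows the argument
// ordering for a `(pd, ph, pw)` padding, assuming the
// `(left, right, top, bottom, front, back)` convention that the `forward`
// comments describe for `functional_pad_3d`. Both function names are
// hypothetical.

```rust
/// Hypothetical helper: expands per-dim padding `(pd, ph, pw)` into the
/// W-first order `[left, right, top, bottom, front, back]`.
fn reversed_padding_repeated_twice(padding: (usize, usize, usize)) -> [usize; 6] {
    let (pd, ph, pw) = padding;
    [pw, pw, ph, ph, pd, pd]
}

/// Spatial size `(D, H, W)` after applying the six pad amounts.
fn padded_dims(dims: (usize, usize, usize), pads: [usize; 6]) -> (usize, usize, usize) {
    let (d, h, w) = dims;
    (
        d + pads[4] + pads[5], // front + back
        h + pads[2] + pads[3], // top + bottom
        w + pads[0] + pads[1], // left + right
    )
}
```

// For `padding = (1, 2, 3)` the six amounts are `[3, 3, 2, 2, 1, 1]`, and a
// `(4, 5, 6)` volume grows to `(6, 9, 12)` before the zero-padding convolution.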
3688
3689    /// The number of learnable scalar parameters (grouped weight plus optional bias).
3690    pub fn num_parameters(&self) -> usize {
3691        let w = self.out_channels
3692            * (self.in_channels / self.groups)
3693            * self.kernel_size.0
3694            * self.kernel_size.1
3695            * self.kernel_size.2;
3696        let b = if self.bias.is_some() {
3697            self.out_channels
3698        } else {
3699            0
3700        };
3701        w + b
3702    }
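
// Illustrative aside: a sketch of the parameter count implied by the weight
// layout `[out_channels, in_channels / groups, kD, kH, kW]` documented on
// `new_full`. `conv3d_param_count` is a hypothetical free function for
// illustration only, and it assumes `groups > 0` divides `in_channels`.

```rust
/// Hypothetical helper: number of scalars in the grouped weight plus the
/// optional per-output-channel bias.
fn conv3d_param_count(
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize, usize),
    groups: usize,
    bias: bool,
) -> usize {
    let (kd, kh, kw) = kernel_size;
    // Each output channel holds a kernel over its group's input channels only.
    let weight = out_channels * (in_channels / groups) * kd * kh * kw;
    let bias_len = if bias { out_channels } else { 0 };
    weight + bias_len
}
```

// Grouping shrinks the weight by a factor of `groups`: an 8-to-16-channel
// `3x3x3` layer with bias has 3472 parameters when dense and 880 with
// `groups = 4`.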
3703
3704    /// Build a `Conv3d` from caller-supplied weight and optional bias tensors.
3705    ///
3706    /// `weight` must have shape `[out_channels, in_channels, kD, kH, kW]`.
3707    /// The resulting layer is dense (`groups = 1`, `dilation = (1, 1, 1)`) so
3708    /// the constructor remains API-compatible with `nn::functional::conv3d`,
3709    /// which infers `in_channels = weight.shape()[1]` and cannot recover
3710    /// `groups` from the weight shape alone.
3711    pub fn from_parts(
3712        weight: Tensor<T>,
3713        bias: Option<Tensor<T>>,
3714        stride: (usize, usize, usize),
3715        padding: (usize, usize, usize),
3716    ) -> FerrotorchResult<Self> {
3717        if weight.ndim() != 5 {
3718            return Err(FerrotorchError::ShapeMismatch {
3719                message: format!(
3720                    "Conv3d::from_parts: weight must be 5-D [out, in, kD, kH, kW], got {:?}",
3721                    weight.shape()
3722                ),
3723            });
3724        }
3725        let out_channels = weight.shape()[0];
3726        let in_channels = weight.shape()[1];
3727        let kernel_size = (weight.shape()[2], weight.shape()[3], weight.shape()[4]);
3728        if let Some(b) = &bias {
3729            if b.ndim() != 1 || b.shape()[0] != out_channels {
3730                return Err(FerrotorchError::ShapeMismatch {
3731                    message: format!(
3732                        "Conv3d::from_parts: bias shape {:?} != [{}]",
3733                        b.shape(),
3734                        out_channels
3735                    ),
3736                });
3737            }
3738        }
3739        Ok(Self {
3740            weight: Parameter::new(weight),
3741            bias: bias.map(Parameter::new),
3742            in_channels,
3743            out_channels,
3744            kernel_size,
3745            stride,
3746            padding,
3747            dilation: (1, 1, 1),
3748            groups: 1,
3749            padding_mode: crate::padding::PaddingMode::Zeros,
3750            string_padding: None,
3751            training: true,
3752        })
3753    }
3754
3755    /// Build a shallow clone with the geometry overridden (used by `forward`
3756    /// to recurse onto the dense zero-padding im2col path after a
3757    /// string-padding / non-zero `padding_mode` pre-pad). `string_padding` is
3758    /// cleared so the recursion runs the numeric-padding path.
3759    fn recurse_clone(
3760        &self,
3761        padding: (usize, usize, usize),
3762        padding_mode: crate::padding::PaddingMode,
3763    ) -> Conv3d<T> {
3764        Conv3d {
3765            weight: Parameter::new(self.weight.tensor().clone()),
3766            bias: self
3767                .bias
3768                .as_ref()
3769                .map(|b| Parameter::new(b.tensor().clone())),
3770            in_channels: self.in_channels,
3771            out_channels: self.out_channels,
3772            kernel_size: self.kernel_size,
3773            stride: self.stride,
3774            padding,
3775            dilation: self.dilation,
3776            groups: self.groups,
3777            padding_mode,
3778            string_padding: None,
3779            training: self.training,
3780        }
3781    }
3782}
3783
3784impl<T: Float> Module<T> for Conv3d<T> {
3785    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3786        // Record autocast decision for conv3d.
3787        let _autocast_cat = autocast_guard("conv3d");
3788
3789        // Unbatched input: `(C, D, H, W)` (rank 4) is accepted in addition to
3790        // the batched `(N, C, D, H, W)` (rank 5) form. Mirrors `batchify` /
3791        // `_conv_forward` at `aten/src/ATen/native/Convolution.cpp:816-831,
3792        // 1040-1047`: an unbatched input is `unsqueeze(0)`d, convolved, then
3793        // `squeeze(0)`d so the output is rank 4. The unsqueeze/squeeze are
3794        // autograd-aware so gradients flow back to the unbatched shape. Closes
3795        // #1604.
3796        if input.ndim() == 4 {
3797            let batched = unsqueeze(input, 0)?;
3798            let output = self.forward(&batched)?;
3799            return squeeze(&output, 0);
3800        }
3801
3802        // String padding ('same' / 'valid'), mirroring the `padding: str`
3803        // branch of `torch.nn.Conv3d` (`conv.py:111-155`,
3804        // `Convolution.cpp:1119-1124`). `Valid` == numeric `padding = 0`;
3805        // `Same` pre-pads each spatial dim asymmetrically (`left = total/2`,
3806        // `right = total - left`) via the autograd-aware `functional_pad_3d`
3807        // then convolves with `padding = 0` — the `convolution_same` ->
3808        // `constant_pad_nd(.., 0)` path (`Convolution.cpp:1098-1108`).
3809        // `functional_pad_3d` takes amounts in `(W, H, D)` order
3810        // (left/right=W, top/bottom=H, front/back=D). The stride>1 rejection
3811        // happened at `with_string_padding` construction (`conv.py:117-120`).
3812        // Closes #1602.
3813        if let Some(sp) = self.string_padding {
3814            match sp {
3815                StringPadding::Valid => {
3816                    return self
3817                        .recurse_clone((0, 0, 0), crate::padding::PaddingMode::Zeros)
3818                        .forward(input);
3819                }
3820                StringPadding::Same => {
3821                    let (kd, kh, kw) = self.kernel_size;
3822                    let (dd, dh, dw) = self.dilation;
3823                    let (front, back) = same_pad_lr(kd, dd);
3824                    let (top, bottom) = same_pad_lr(kh, dh);
3825                    let (left, right) = same_pad_lr(kw, dw);
3826                    let padded = crate::padding::functional_pad_3d(
3827                        input,
3828                        left,
3829                        right,
3830                        top,
3831                        bottom,
3832                        front,
3833                        back,
3834                        self.padding_mode,
3835                        <T as num_traits::Zero>::zero(),
3836                    )?;
3837                    return self
3838                        .recurse_clone((0, 0, 0), crate::padding::PaddingMode::Zeros)
3839                        .forward(&padded);
3840                }
3841            }
3842        }
3843
3844        // Non-zero padding modes: pre-pad the input with the requested
3845        // boundary mode and then convolve with padding = 0. Mirrors
3846        // `torch/nn/modules/conv.py` `Conv3d._conv_forward` (`conv.py:716-732`):
3847        //   if self.padding_mode != 'zeros':
3848        //       F.conv3d(F.pad(input, self._reversed_padding_repeated_twice,
3849        //                      mode=self.padding_mode), ..., padding=_triple(0), ...)
3850        // For padding `(pd, ph, pw)`, `_reversed_padding_repeated_twice` is
3851        // `[pw, pw, ph, ph, pd, pd]` (`conv.py:157` reverses the per-dim order
3852        // because `F.pad` takes paddings starting from the last dimension). The 6-tuple maps
3853        // to `functional_pad_3d(left=pw, right=pw, top=ph, bottom=ph,
3854        // front=pd, back=pd)`. The pre-pad is autograd-aware (`Pad3dBackward`)
3855        // so input gradients flow through the boundary. Closes #1443.
3856        if self.padding_mode != crate::padding::PaddingMode::Zeros
3857            && (self.padding.0 != 0 || self.padding.1 != 0 || self.padding.2 != 0)
3858        {
3859            let (pd, ph, pw) = self.padding;
3860            let padded = crate::padding::functional_pad_3d(
3861                input,
3862                pw,
3863                pw,
3864                ph,
3865                ph,
3866                pd,
3867                pd,
3868                self.padding_mode,
3869                <T as num_traits::Zero>::zero(),
3870            )?;
3871            // Recurse on a zero-padding variant: build a shallow clone with
3872            // padding = (0,0,0) and padding_mode = Zeros so the existing
3873            // im2col-on-zero-pad path runs without re-padding.
3874            return self
3875                .recurse_clone((0, 0, 0), crate::padding::PaddingMode::Zeros)
3876                .forward(&padded);
3877        }
3878
3879        // Validate input shape: [B, C_in, D, H, W].
3880        if input.ndim() != 5 {
3881            return Err(FerrotorchError::InvalidArgument {
3882                message: format!(
3883                    "Conv3d expects 5-D input [B, C, D, H, W], got {:?}",
3884                    input.shape()
3885                ),
3886            });
3887        }
3888
3889        let batch = input.shape()[0];
3890        let c_in = input.shape()[1];
3891        let d = input.shape()[2];
3892        let h = input.shape()[3];
3893        let w = input.shape()[4];
3894
3895        if c_in != self.in_channels {
3896            return Err(FerrotorchError::ShapeMismatch {
3897                message: format!(
3898                    "Conv3d: expected {} input channels, got {}",
3899                    self.in_channels, c_in
3900                ),
3901            });
3902        }
3903
3904        let (kd, kh, kw) = self.kernel_size;
3905        let (sd, sh, sw) = self.stride;
3906        let (pd, ph, pw) = self.padding;
3907        let (dd, dh, dw) = self.dilation;
3908        let groups = self.groups;
3909
3910        // Effective kernel extent after dilation (`ConvUtils.h:255`).
3911        let eff_kd = dd * (kd - 1) + 1;
3912        let eff_kh = dh * (kh - 1) + 1;
3913        let eff_kw = dw * (kw - 1) + 1;
3914
3915        let d_padded = d + 2 * pd;
3916        let h_padded = h + 2 * ph;
3917        let w_padded = w + 2 * pw;
3918        if d_padded < eff_kd || h_padded < eff_kh || w_padded < eff_kw {
3919            return Err(FerrotorchError::InvalidArgument {
3920                message: format!(
3921                    "Conv3d: padded input ({d_padded}, {h_padded}, {w_padded}) is smaller than effective kernel ({eff_kd}, {eff_kh}, {eff_kw})"
3922                ),
3923            });
3924        }
3925
3926        let d_out = (d_padded - eff_kd) / sd + 1;
3927        let h_out = (h_padded - eff_kh) / sh + 1;
3928        let w_out = (w_padded - eff_kw) / sw + 1;
3929
3930        // Save the input device so we can restore it on the output.
3931        let input_device = input.device();
3932
3933        // ---- CPU path (dense, dilated, grouped, and grouped+dilated) ----
3934        // Partitions channels by `groups` exactly like Conv2d: each group's
3935        // input slice [B, in_per_group, D, H, W] is convolved with its weight
3936        // slice via the dilated 3-D im2col + GEMM and the outputs are stacked
3937        // along the C_out axis (mirroring `Convolution.cpp:1723-1729`).
3938        let input_data = input.data_vec()?;
3939        let weight_data = self.weight.data_vec()?;
3940
3941        let zero = <T as num_traits::Zero>::zero();
3942        let spatial_in = d * h * w;
3943        let spatial_out = d_out * h_out * w_out;
3944        let mut output = vec![zero; batch * self.out_channels * spatial_out];
3945
3946        // Per-group dimensions.
3947        let in_per_group = self.in_channels / groups;
3948        let out_per_group = self.out_channels / groups;
3949        let group_col_rows = in_per_group * kd * kh * kw;
3950        let weight_per_group_numel = out_per_group * group_col_rows;
3951        let col_cols = spatial_out;
3952
3953        // Saved im2col columns for autograd (dense channel layout
3954        // `[B, C_in * kD * kH * kW, D_out*H_out*W_out]`), so the backward
3955        // accumulates grad_input uniformly across groups (like Conv2dBackward).
3956        let saved_cols_rows = self.in_channels * kd * kh * kw;
3957        let mut saved_cols: Vec<T> = if is_grad_enabled()
3958            && (input.requires_grad()
3959                || self.weight.requires_grad()
3960                || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
3961        {
3962            vec![zero; batch * saved_cols_rows * col_cols]
3963        } else {
3964            Vec::new()
3965        };
3966        let save_cols = !saved_cols.is_empty();
3967        let kvol = kd * kh * kw;
3968
3969        for g in 0..groups {
3970            // Slice the input channels for this group: [B, in_per_group, D, H, W].
3971            let mut group_input = vec![zero; batch * in_per_group * spatial_in];
3972            for b in 0..batch {
3973                for c in 0..in_per_group {
3974                    let src_c = g * in_per_group + c;
3975                    let src_start = b * self.in_channels * spatial_in + src_c * spatial_in;
3976                    let dst_start = b * in_per_group * spatial_in + c * spatial_in;
3977                    group_input[dst_start..dst_start + spatial_in]
3978                        .copy_from_slice(&input_data[src_start..src_start + spatial_in]);
3979                }
3980            }
3981
3982            let (g_cols, g_col_rows, g_col_cols) = im2col_3d_dilated(
3983                &group_input,
3984                batch,
3985                in_per_group,
3986                d,
3987                h,
3988                w,
3989                kd,
3990                kh,
3991                kw,
3992                sd,
3993                sh,
3994                sw,
3995                pd,
3996                ph,
3997                pw,
3998                dd,
3999                dh,
4000                dw,
4001            );
4002            debug_assert_eq!(g_col_rows, group_col_rows);
4003            debug_assert_eq!(g_col_cols, col_cols);
4004
4005            // Save into the dense [C_in * kvol, col_cols] layout if needed.
4006            if save_cols {
4007                for b in 0..batch {
4008                    for c in 0..in_per_group {
4009                        let dst_c = g * in_per_group + c;
4010                        for kk in 0..kvol {
4011                            let src_row = c * kvol + kk;
4012                            let dst_row = dst_c * kvol + kk;
4013                            let src_off = b * group_col_rows * col_cols + src_row * col_cols;
4014                            let dst_off = b * saved_cols_rows * col_cols + dst_row * col_cols;
4015                            saved_cols[dst_off..dst_off + col_cols]
4016                                .copy_from_slice(&g_cols[src_off..src_off + col_cols]);
4017                        }
4018                    }
4019                }
4020            }
4021
4022            // Group's slice of the weight: [out_per_group, in_per_group, kD, kH, kW]
4023            // flattened to [out_per_group, group_col_rows].
4024            let w_group_start = g * weight_per_group_numel;
4025            let w_group_end = w_group_start + weight_per_group_numel;
4026            let weight_group_2d = Tensor::from_storage(
4027                TensorStorage::cpu(weight_data[w_group_start..w_group_end].to_vec()),
4028                vec![out_per_group, group_col_rows],
4029                false,
4030            )?;
4031
4032            for b in 0..batch {
4033                let col_start = b * group_col_rows * col_cols;
4034                let col_end = col_start + group_col_rows * col_cols;
4035                let cols_b = Tensor::from_storage(
4036                    TensorStorage::cpu(g_cols[col_start..col_end].to_vec()),
4037                    vec![group_col_rows, col_cols],
4038                    false,
4039                )?;
4040
4041                let out_b = mm(&weight_group_2d, &cols_b)?;
4042                let out_data = out_b.data()?;
4043                for oc in 0..out_per_group {
4044                    let dst_c = g * out_per_group + oc;
4045                    let dst_start = b * self.out_channels * spatial_out + dst_c * spatial_out;
4046                    let src_start = oc * spatial_out;
4047                    output[dst_start..dst_start + spatial_out]
4048                        .copy_from_slice(&out_data[src_start..src_start + spatial_out]);
4049                }
4050            }
4051        }
4052
4053        // Add bias if present: broadcast [C_out] over [B, C_out, D_out, H_out, W_out].
4054        if let Some(ref bias) = self.bias {
4055            let bias_data = bias.data_vec()?;
4056            for b in 0..batch {
4057                for c in 0..self.out_channels {
4058                    let bval = bias_data[c];
4059                    for s in 0..spatial_out {
4060                        output[b * self.out_channels * spatial_out + c * spatial_out + s] += bval;
4061                    }
4062                }
4063            }
4064        }
4065
4066        let result = Tensor::from_storage(
4067            TensorStorage::cpu(output),
4068            vec![batch, self.out_channels, d_out, h_out, w_out],
4069            false,
4070        )?;
4071
4072        // Attach backward if gradients are enabled and any input/param requires grad.
4073        if save_cols {
4074            let grad_fn = Arc::new(Conv3dBackward {
4075                input: input.clone(),
4076                weight: self.weight.tensor().clone(),
4077                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
4078                in_channels: self.in_channels,
4079                out_channels: self.out_channels,
4080                kernel_size: self.kernel_size,
4081                stride: self.stride,
4082                padding: self.padding,
4083                dilation: self.dilation,
4084                groups: self.groups,
4085                cols: saved_cols,
4086                col_rows: saved_cols_rows,
4087                col_cols,
4088                d_out,
4089                h_out,
4090                w_out,
4091            });
4092            Tensor::from_operation(
4093                TensorStorage::cpu(result.data()?.to_vec()),
4094                result.shape().to_vec(),
4095                grad_fn,
4096            )?
4097            .to(input_device) // restore device
4098        } else {
4099            result.to(input_device)
4100        }
4101    }
4102
4103    fn parameters(&self) -> Vec<&Parameter<T>> {
4104        let mut params = vec![&self.weight];
4105        if let Some(ref b) = self.bias {
4106            params.push(b);
4107        }
4108        params
4109    }
4110
4111    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
4112        let mut params = vec![&mut self.weight];
4113        if let Some(ref mut b) = self.bias {
4114            params.push(b);
4115        }
4116        params
4117    }
4118
4119    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
4120        let mut params = vec![("weight".to_string(), &self.weight)];
4121        if let Some(ref b) = self.bias {
4122            params.push(("bias".to_string(), b));
4123        }
4124        params
4125    }
4126
4127    fn train(&mut self) {
4128        self.training = true;
4129    }
4130
4131    fn eval(&mut self) {
4132        self.training = false;
4133    }
4134
4135    fn is_training(&self) -> bool {
4136        self.training
4137    }
4138}
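
// Illustrative aside: a reduced 1-D, single-channel sketch of the
// im2col + matmul idea used by `forward`, including the effective kernel
// extent `dilation * (k - 1) + 1` and the output length
// `(padded - eff_kernel) / stride + 1`. All three function names are
// hypothetical; the real path uses `im2col_3d_dilated` and `mm`, whose
// bodies are not shown in this section.

```rust
/// Output length along one dimension, or `None` for invalid geometry.
fn conv_out_len(
    input: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> Option<usize> {
    if kernel == 0 || stride == 0 || dilation == 0 {
        return None;
    }
    let eff_kernel = dilation * (kernel - 1) + 1;
    let padded = input + 2 * padding;
    if padded < eff_kernel {
        return None;
    }
    Some((padded - eff_kernel) / stride + 1)
}

/// Dilated 1-D im2col: returns a row-major `[kernel, out_len]` matrix where
/// row `k` holds the input sample seen by kernel tap `k` at each output.
fn im2col_1d(
    input: &[f32],
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> Option<(Vec<f32>, usize)> {
    let out_len = conv_out_len(input.len(), kernel, stride, padding, dilation)?;
    let mut cols = vec![0.0f32; kernel * out_len];
    for k in 0..kernel {
        for o in 0..out_len {
            // Position in the zero-padded signal; entries in the pad stay 0.
            let pos = o * stride + k * dilation;
            if pos >= padding && pos - padding < input.len() {
                cols[k * out_len + o] = input[pos - padding];
            }
        }
    }
    Some((cols, out_len))
}

/// Convolution as a `[1, kernel] x [kernel, out_len]` matrix product.
fn conv1d_via_im2col(
    input: &[f32],
    weight: &[f32],
    stride: usize,
    padding: usize,
    dilation: usize,
) -> Option<Vec<f32>> {
    let (cols, out_len) = im2col_1d(input, weight.len(), stride, padding, dilation)?;
    let mut out = vec![0.0f32; out_len];
    for (k, wk) in weight.iter().enumerate() {
        for o in 0..out_len {
            out[o] += wk * cols[k * out_len + o];
        }
    }
    Some(out)
}
```

// The 3-D layer applies the same recipe per batch element and per group,
// with `kD * kH * kW` taps per input channel instead of `kernel` taps.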
4139
4140// ---------------------------------------------------------------------------
4141// Conv3dBackward
4142// ---------------------------------------------------------------------------
4143
4144/// Backward function for the `Conv3d` forward pass.
4145///
4146/// Saved tensors:
4147/// - `input`: the original 5-D input
4148/// - `weight`: the 5-D kernel `[C_out, C_in / groups, kD, kH, kW]`
4149/// - `bias`: optional 1-D bias
4150/// - `cols`: the dilated im2col_3d columns with **dense channel layout**
4151///   `[B, C_in * kD * kH * kW, D_out*H_out*W_out]` (saved regardless of
4152///   `groups`, so the backward takes the per-group slice at gradient time,
4153///   mirroring `Conv2dBackward`).
4154/// - `dilation`, `groups`: descriptors to reconstruct the per-group +
4155///   dilated math.
4156#[derive(Debug)]
4157struct Conv3dBackward<T: Float> {
4158    input: Tensor<T>,
4159    weight: Tensor<T>,
4160    bias: Option<Tensor<T>>,
4161    in_channels: usize,
4162    out_channels: usize,
4163    kernel_size: (usize, usize, usize),
4164    stride: (usize, usize, usize),
4165    padding: (usize, usize, usize),
4166    dilation: (usize, usize, usize),
4167    groups: usize,
4168    cols: Vec<T>,
4169    col_rows: usize,
4170    col_cols: usize,
4171    d_out: usize,
4172    h_out: usize,
4173    w_out: usize,
4174}
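
// Illustrative aside: a sketch of how a per-group slice can be taken from
// the dense column layout `[C_in * kvol, col_cols]` (one batch element),
// using the row mapping `src_row = (g * in_per_group + c) * kvol + kk` and
// `dst_row = c * kvol + kk` that appears in this file. The function name
// `gather_group_cols` and the `f32` element type are assumptions made for
// the example.

```rust
/// Hypothetical helper: copies the rows of group `g` out of a dense,
/// row-major `[c_in * kvol, col_cols]` column matrix.
fn gather_group_cols(
    dense: &[f32],
    g: usize,
    in_per_group: usize,
    kvol: usize,
    col_cols: usize,
) -> Vec<f32> {
    let group_rows = in_per_group * kvol;
    let mut out = vec![0.0f32; group_rows * col_cols];
    for c in 0..in_per_group {
        for kk in 0..kvol {
            // Dense row index of channel `c` of group `g`, kernel tap `kk`.
            let src_row = (g * in_per_group + c) * kvol + kk;
            let dst_row = c * kvol + kk;
            let src = src_row * col_cols;
            let dst = dst_row * col_cols;
            out[dst..dst + col_cols].copy_from_slice(&dense[src..src + col_cols]);
        }
    }
    out
}
```

// Because `src_row = g * (in_per_group * kvol) + dst_row`, each group's rows
// are contiguous in the dense layout, so the gathered block equals the plain
// sub-slice starting at row `g * in_per_group * kvol`.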
4175
4176impl<T: Float> GradFn<T> for Conv3dBackward<T> {
4177    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
4178        // grad_output shape: [B, C_out, D_out, H_out, W_out]
4179        let input_device = self.input.device();
4180        let weight_device = self.weight.device();
4181        let bias_device = self.bias.as_ref().map(|b| b.device());
4182        let go_data = grad_output.data_vec()?;
4183        let batch = self.input.shape()[0];
4184        let d = self.input.shape()[2];
4185        let h = self.input.shape()[3];
4186        let w = self.input.shape()[4];
4187        let (kd, kh, kw) = self.kernel_size;
4188        let (sd, sh, sw) = self.stride;
4189        let (pd, ph, pw) = self.padding;
4190        let (dd, dh, dw) = self.dilation;
4191        let groups = self.groups;
4192        let in_per_group = self.in_channels / groups;
4193        let out_per_group = self.out_channels / groups;
4194        let kvol = kd * kh * kw;
4195        let group_col_rows = in_per_group * kvol;
4196        let spatial_in = d * h * w;
4197        let spatial_out = self.d_out * self.h_out * self.w_out;
4198        let zero = <T as num_traits::Zero>::zero();
4199
4200        // --- grad_weight ---
4201        // Per group `g`: gw_g += grad_output_b_g @ cols_b_g^T, stacked along
4202        // the C_out axis to recover [C_out, C_in/G, kD, kH, kW]. Mirrors
4203        // Conv2dBackward.
4204        let grad_weight = if self.weight.requires_grad() {
4205            let weight_numel = self.out_channels * group_col_rows;
4206            let mut gw_accum = vec![zero; weight_numel];
4207            let weight_per_group_numel = out_per_group * group_col_rows;
4208
4209            for g in 0..groups {
4210                for b in 0..batch {
4211                    // Slice grad_output for this group: [out_per_group, spatial_out].
4212                    let mut go_g = vec![zero; out_per_group * spatial_out];
4213                    for oc in 0..out_per_group {
4214                        let src_c = g * out_per_group + oc;
4215                        let src_start = b * self.out_channels * spatial_out + src_c * spatial_out;
4216                        let dst_start = oc * spatial_out;
4217                        go_g[dst_start..dst_start + spatial_out]
4218                            .copy_from_slice(&go_data[src_start..src_start + spatial_out]);
4219                    }
4220                    let go_b_g = Tensor::from_storage(
4221                        TensorStorage::cpu(go_g),
4222                        vec![out_per_group, spatial_out],
4223                        false,
4224                    )?;
4225
4226                    // Slice cols for this group: [in_per_group * kvol, col_cols].
4227                    let mut cols_g = vec![zero; group_col_rows * self.col_cols];
4228                    for c in 0..in_per_group {
4229                        let src_c = g * in_per_group + c;
4230                        for kk in 0..kvol {
4231                            let src_row = src_c * kvol + kk;
4232                            let dst_row = c * kvol + kk;
4233                            let src_off =
4234                                b * self.col_rows * self.col_cols + src_row * self.col_cols;
4235                            let dst_off = dst_row * self.col_cols;
4236                            cols_g[dst_off..dst_off + self.col_cols]
4237                                .copy_from_slice(&self.cols[src_off..src_off + self.col_cols]);
4238                        }
4239                    }
4240                    let cols_b_g = Tensor::from_storage(
4241                        TensorStorage::cpu(cols_g),
4242                        vec![group_col_rows, self.col_cols],
4243                        false,
4244                    )?;
4245
4246                    let cols_bt = transpose(&cols_b_g)?;
4247                    let gw_b = mm(&go_b_g, &cols_bt)?;
4248                    let gw_data = gw_b.data()?;
4249
4250                    let dst_off = g * weight_per_group_numel;
4251                    for i in 0..weight_per_group_numel {
4252                        gw_accum[dst_off + i] += gw_data[i];
4253                    }
4254                }
4255            }
4256
4257            Some(
4258                Tensor::from_storage(
4259                    TensorStorage::cpu(gw_accum),
4260                    vec![self.out_channels, in_per_group, kd, kh, kw],
4261                    false,
4262                )?
4263                .to(weight_device)?,
4264            )
4265        } else {
4266            None
4267        };
4268
4269        // --- grad_bias ---
4270        // Sum grad_output over batch + spatial. Bias is per-output-channel,
4271        // identical for any groups setting.
4272        let grad_bias = match &self.bias {
4273            Some(b) if b.requires_grad() => {
4274                let mut gb = vec![zero; self.out_channels];
4275                for batch_idx in 0..batch {
4276                    for c in 0..self.out_channels {
4277                        for s in 0..spatial_out {
4278                            gb[c] += go_data
4279                                [batch_idx * self.out_channels * spatial_out + c * spatial_out + s];
4280                        }
4281                    }
4282                }
4283                let target_dev = bias_device.unwrap_or(input_device);
4284                Some(
4285                    Tensor::from_storage(TensorStorage::cpu(gb), vec![self.out_channels], false)?
4286                        .to(target_dev)?,
4287                )
4288            }
4289            _ => None,
4290        };
4291
4292        // --- grad_input ---
4293        // Per group `g`: weight_g_2d^T @ grad_output_b_g -> [in_per_group *
4294        // kvol, spatial_out], then col2im_3d_dilated -> [in_per_group, D, H, W]
4295        // placed into the right in-channel slice of [B, C_in, D, H, W].
4296        // Mirrors Conv2dBackward.
4297        let grad_input = if self.input.requires_grad() {
4298            let weight_data = self.weight.data_vec()?;
4299            let mut grad_input_data = vec![zero; batch * self.in_channels * spatial_in];
4300            let weight_per_group_numel = out_per_group * group_col_rows;
4301
4302            for g in 0..groups {
4303                let w_off = g * weight_per_group_numel;
4304                let weight_g_2d = Tensor::from_storage(
4305                    TensorStorage::cpu(weight_data[w_off..w_off + weight_per_group_numel].to_vec()),
4306                    vec![out_per_group, group_col_rows],
4307                    false,
4308                )?;
4309                let weight_g_t = transpose(&weight_g_2d)?;
4310
4311                let mut grad_cols_g = vec![zero; batch * group_col_rows * self.col_cols];
4312                for b in 0..batch {
4313                    let mut go_g = vec![zero; out_per_group * spatial_out];
4314                    for oc in 0..out_per_group {
4315                        let src_c = g * out_per_group + oc;
4316                        let src_start = b * self.out_channels * spatial_out + src_c * spatial_out;
4317                        let dst_start = oc * spatial_out;
4318                        go_g[dst_start..dst_start + spatial_out]
4319                            .copy_from_slice(&go_data[src_start..src_start + spatial_out]);
4320                    }
4321                    let go_b_g = Tensor::from_storage(
4322                        TensorStorage::cpu(go_g),
4323                        vec![out_per_group, spatial_out],
4324                        false,
4325                    )?;
4326
4327                    let gc_b = mm(&weight_g_t, &go_b_g)?;
4328                    let gc_data = gc_b.data()?;
4329                    let gc_start = b * group_col_rows * self.col_cols;
4330                    grad_cols_g[gc_start..gc_start + group_col_rows * self.col_cols]
4331                        .copy_from_slice(gc_data);
4332                }
4333
4334                // col2im_3d_dilated scatters this group's columns back to
4335                // [B, in_per_group, D, H, W] honouring the dilation factors.
4336                let gi_g = col2im_3d_dilated(
4337                    &grad_cols_g,
4338                    batch,
4339                    in_per_group,
4340                    d,
4341                    h,
4342                    w,
4343                    kd,
4344                    kh,
4345                    kw,
4346                    sd,
4347                    sh,
4348                    sw,
4349                    pd,
4350                    ph,
4351                    pw,
4352                    dd,
4353                    dh,
4354                    dw,
4355                    self.d_out,
4356                    self.h_out,
4357                    self.w_out,
4358                );
4359
4360                for b in 0..batch {
4361                    for c in 0..in_per_group {
4362                        let dst_c = g * in_per_group + c;
4363                        let dst_start = b * self.in_channels * spatial_in + dst_c * spatial_in;
4364                        let src_start = b * in_per_group * spatial_in + c * spatial_in;
4365                        grad_input_data[dst_start..dst_start + spatial_in]
4366                            .copy_from_slice(&gi_g[src_start..src_start + spatial_in]);
4367                    }
4368                }
4369            }
4370
4371            Some(
4372                Tensor::from_storage(
4373                    TensorStorage::cpu(grad_input_data),
4374                    self.input.shape().to_vec(),
4375                    false,
4376                )?
4377                .to(input_device)?,
4378            )
4379        } else {
4380            None
4381        };
4382
4383        // Return exactly as many gradients as inputs() returns.
4384        let mut grads = vec![grad_input, grad_weight];
4385        if self.bias.is_some() {
4386            grads.push(grad_bias);
4387        }
4388        Ok(grads)
4389    }
4390
4391    fn inputs(&self) -> Vec<&Tensor<T>> {
4392        let mut v = vec![&self.input, &self.weight];
4393        if let Some(ref b) = self.bias {
4394            v.push(b);
4395        }
4396        v
4397    }
4398
4399    fn name(&self) -> &'static str {
4400        "Conv3dBackward"
4401    }
4402}
4403
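// Illustrative sketch (not part of this module's API): the backward passes
// above repeatedly copy one group's channels out of a contiguous
// `[batch, channels, spatial]` buffer before each per-group matmul. The
// helper name `slice_group` and the fixed `f32` element type are assumptions
// made for this example; the real code inlines the loop and is generic over
// `T: Float`.

```rust
/// Copy the channels of group `g` from a contiguous `[batch, channels, spatial]`
/// buffer into a new `[batch, channels / groups, spatial]` buffer.
fn slice_group(
    data: &[f32],
    batch: usize,
    channels: usize,
    spatial: usize,
    groups: usize,
    g: usize,
) -> Vec<f32> {
    assert!(groups > 0 && channels % groups == 0 && g < groups);
    assert_eq!(data.len(), batch * channels * spatial);
    let per_group = channels / groups;
    let mut out = vec![0.0; batch * per_group * spatial];
    for b in 0..batch {
        for c in 0..per_group {
            // Group `g` owns source channels `g * per_group .. (g + 1) * per_group`.
            let src = (b * channels + g * per_group + c) * spatial;
            let dst = (b * per_group + c) * spatial;
            out[dst..dst + spatial].copy_from_slice(&data[src..src + spatial]);
        }
    }
    out
}
```

// Writing a group's result back into the full buffer is the same loop with
// `src` and `dst` swapped, which is how `grad_input_data` is filled above.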
4404// ---------------------------------------------------------------------------
4405// ConvTranspose1d
4406// ---------------------------------------------------------------------------
4407
4408/// A 1-D transposed convolution (deconvolution) layer.
4409///
4410/// Applies a transposed temporal convolution over an input `[B, C_in, L]`.
4411/// Used for upsampling in generative models and decoder networks.
4412/// Equivalent to `torch.nn.ConvTranspose1d`.
4413///
4414/// # Implementation
4415///
4416/// Runs each channel group through the 2-D transposed-convolution helper with
4417/// a dummy height axis (`H = 1`) and writes the result as `[B, C_out, L_out]`.
4418///
4419/// # Shape
4420///
4421/// - Input: `[B, in_channels, L]`
4422/// - Output: `[B, out_channels, L_out]`
4423///
4424/// where `L_out = (L - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1`.
4425#[derive(Debug)]
4426pub struct ConvTranspose1d<T: Float> {
4427    /// Learnable kernel weights `[in_channels, out_channels / groups, kernel_size]`.
4428    ///
4429    /// Note: the channel ordering is transposed compared to `Conv1d`
4430    /// (`in_channels` first), per `torch/nn/modules/conv.py:161-167`.
4431    weight: Parameter<T>,
4432    /// Optional learnable bias `[out_channels]`.
4433    bias: Option<Parameter<T>>,
4434    /// Number of input channels.
4435    in_channels: usize,
4436    /// Number of output channels.
4437    out_channels: usize,
4438    /// Kernel length.
4439    kernel_size: usize,
4440    /// Stride.
4441    stride: usize,
4442    /// Zero-padding removed from both sides of the output.
4443    padding: usize,
4444    /// Additional size added to one side of the output.
4445    output_padding: usize,
4446    /// Dilation. `1` is the dense default (`dilation` arg of
4447    /// `F.conv_transpose1d`, `torch/nn/modules/conv.py:1000-1009`).
4448    dilation: usize,
4449    /// Number of blocked input/output channel groups. `1` is dense. Must divide
4450    /// both `in_channels` and `out_channels`. Transposed weight is
4451    /// `[in_channels, out_channels / groups, K]`; the per-group partition
4452    /// mirrors `aten/src/ATen/native/Convolution.cpp:1723-1729`.
4453    groups: usize,
4454    /// Whether the module is in training mode.
4455    training: bool,
4456}
4457
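// Illustrative sketch (not part of this module's API): the output length of a
// 1-D transposed convolution, including the dilation term. The function name
// `conv_transpose1d_out_len` and its `Option` return are assumptions made for
// this example; it returns `None` on zero-sized arguments, on overflow, or
// when the padding would trim more than the un-padded output length.

```rust
/// `L_out = (L - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
///          + output_padding + 1`, computed with checked arithmetic.
fn conv_transpose1d_out_len(
    l: usize,
    kernel_size: usize,
    stride: usize,
    padding: usize,
    output_padding: usize,
    dilation: usize,
) -> Option<usize> {
    if l == 0 || kernel_size == 0 || stride == 0 || dilation == 0 {
        return None;
    }
    // Length before the `2 * padding` trim is applied.
    let grown = (l - 1)
        .checked_mul(stride)?
        .checked_add(dilation.checked_mul(kernel_size - 1)?)?
        .checked_add(output_padding)?
        .checked_add(1)?;
    grown.checked_sub(padding.checked_mul(2)?)
}
```

// For example, `L = 4`, `kernel_size = 3`, `stride = 2`, `padding = 1`,
// `output_padding = 1`, `dilation = 1` gives `L_out = 8`.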
4458impl<T: Float> ConvTranspose1d<T> {
4459    /// Create a new `ConvTranspose1d` layer (dense, dilation `1`, `groups = 1`).
4460    ///
4461    /// Weight is initialized with Kaiming uniform (ReLU gain).
4462    /// Bias, if enabled, is initialized U(-bound, bound) with
4463    /// `bound = 1/sqrt(fan_in)` per `torch/nn/modules/conv.py:198-201`.
4464    ///
4465    /// Thin shim over [`ConvTranspose1d::new_full`].
4466    pub fn new(
4467        in_channels: usize,
4468        out_channels: usize,
4469        kernel_size: usize,
4470        stride: usize,
4471        padding: usize,
4472        output_padding: usize,
4473        bias: bool,
4474    ) -> FerrotorchResult<Self> {
4475        Self::new_full(
4476            in_channels,
4477            out_channels,
4478            kernel_size,
4479            stride,
4480            padding,
4481            output_padding,
4482            1,
4483            1,
4484            bias,
4485        )
4486    }
4487
4488    /// Create a new `ConvTranspose1d` with the full PyTorch-shaped argument set,
4489    /// including `dilation` and `groups`.
4490    ///
4491    /// Mirrors `torch.nn.ConvTranspose1d.__init__` (`torch/nn/modules/conv.py:
4492    /// 944-978`). `groups` must divide BOTH `in_channels` and `out_channels`
4493    /// (`conv.py:105-110`). Transposed weight layout `[in_channels,
4494    /// out_channels / groups, K]` (`conv.py:161-167`).
4495    #[allow(clippy::too_many_arguments)]
4496    pub fn new_full(
4497        in_channels: usize,
4498        out_channels: usize,
4499        kernel_size: usize,
4500        stride: usize,
4501        padding: usize,
4502        output_padding: usize,
4503        dilation: usize,
4504        groups: usize,
4505        bias: bool,
4506    ) -> FerrotorchResult<Self> {
4507        if in_channels == 0 || out_channels == 0 {
4508            return Err(FerrotorchError::InvalidArgument {
4509                message: "in_channels and out_channels must be > 0".into(),
4510            });
4511        }
4512        if kernel_size == 0 {
4513            return Err(FerrotorchError::InvalidArgument {
4514                message: "kernel_size must be > 0".into(),
4515            });
4516        }
4517        if stride == 0 {
4518            return Err(FerrotorchError::InvalidArgument {
4519                message: "stride must be > 0".into(),
4520            });
4521        }
4522        if dilation == 0 {
4523            return Err(FerrotorchError::InvalidArgument {
4524                message: "dilation must be > 0".into(),
4525            });
4526        }
4527        if groups == 0 {
4528            return Err(FerrotorchError::InvalidArgument {
4529                message: "groups must be a positive integer".into(),
4530            });
4531        }
4532        if in_channels % groups != 0 {
4533            return Err(FerrotorchError::InvalidArgument {
4534                message: format!(
4535                    "in_channels ({in_channels}) must be divisible by groups ({groups})"
4536                ),
4537            });
4538        }
4539        if out_channels % groups != 0 {
4540            return Err(FerrotorchError::InvalidArgument {
4541                message: format!(
4542                    "out_channels ({out_channels}) must be divisible by groups ({groups})"
4543                ),
4544            });
4545        }
4546        if output_padding >= stride.max(dilation) {
4547            return Err(FerrotorchError::InvalidArgument {
4548                message: "output_padding must be strictly less than max(stride, dilation)".into(),
4549            });
4550        }
4551
4552        // Weight shape: [in_channels, out_channels / groups, K] (transposed layout).
4553        let out_per_group = out_channels / groups;
4554        let mut weight = Parameter::zeros(&[in_channels, out_per_group, kernel_size])?;
4555        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
4556
4557        let bias_param = if bias {
4558            let mut b = Parameter::zeros(&[out_channels])?;
4559            // `torch/nn/modules/conv.py:198-201`: bias U(-bound, bound) with
4560            //   `bound = 1 / sqrt(fan_in)`. ConvTranspose1d weight shape
4561            //   `[in_channels, out_channels/groups, K]`, fan_in = (out/groups) * K.
4562            let fan_in = out_per_group * kernel_size;
4563            let bound = if fan_in > 0 {
4564                1.0 / (fan_in as f64).sqrt()
4565            } else {
4566                0.0
4567            };
4568            uniform_init(&mut b, -bound, bound)?;
4569            Some(b)
4570        } else {
4571            None
4572        };
4573
4574        Ok(Self {
4575            weight,
4576            bias: bias_param,
4577            in_channels,
4578            out_channels,
4579            kernel_size,
4580            stride,
4581            padding,
4582            output_padding,
4583            dilation,
4584            groups,
4585            training: true,
4586        })
4587    }
4588
4589    /// Configure the boundary handling for the spatial padding.
4590    ///
4591    /// Only [`crate::padding::PaddingMode::Zeros`] is accepted: upstream
4592    /// `_ConvTransposeNd.__init__` raises
4593    /// `ValueError('Only "zeros" padding mode is supported for ConvTranspose1d')`
4594    /// for any non-`zeros` mode (`torch/nn/modules/conv.py:755-758`). This
4595    /// matches that behaviour by returning an error rather than silently
4596    /// accepting the unsupported mode (R-DEV-2). The returned layer is
4597    /// unchanged (the only valid mode is `Zeros`, the constructed default).
4598    /// Closes #1443.
4599    pub fn with_padding_mode(self, mode: crate::padding::PaddingMode) -> FerrotorchResult<Self> {
4600        reject_non_zeros_transpose(mode, "ConvTranspose1d")?;
4601        Ok(self)
4602    }
4603
4604    /// The number of learnable scalar parameters.
4605    pub fn num_parameters(&self) -> usize {
4606        let w = self.in_channels * (self.out_channels / self.groups) * self.kernel_size;
4607        let b = if self.bias.is_some() {
4608            self.out_channels
4609        } else {
4610            0
4611        };
4612        w + b
4613    }
4614
4615    /// Build a `ConvTranspose1d` from caller-supplied weight and optional bias.
4616    ///
4617    /// `weight` must have shape `[in_channels, out_channels, kernel_size]`
4618    /// (transposed channel ordering vs `Conv1d`). Used by
4619    /// `nn::functional::conv_transpose1d`.
4620    pub fn from_parts(
4621        weight: Tensor<T>,
4622        bias: Option<Tensor<T>>,
4623        stride: usize,
4624        padding: usize,
4625        output_padding: usize,
4626    ) -> FerrotorchResult<Self> {
4627        if weight.ndim() != 3 {
4628            return Err(FerrotorchError::ShapeMismatch {
4629                message: format!(
4630                    "ConvTranspose1d::from_parts: weight must be 3-D [in, out, k], got {:?}",
4631                    weight.shape()
4632                ),
4633            });
4634        }
4635        let in_channels = weight.shape()[0];
4636        let out_channels = weight.shape()[1];
4637        let kernel_size = weight.shape()[2];
4638        if output_padding >= stride {
4639            return Err(FerrotorchError::InvalidArgument {
4640                message: "output_padding must be strictly less than stride".into(),
4641            });
4642        }
4643        if let Some(b) = &bias {
4644            if b.ndim() != 1 || b.shape()[0] != out_channels {
4645                return Err(FerrotorchError::ShapeMismatch {
4646                    message: format!(
4647                        "ConvTranspose1d::from_parts: bias shape {:?} != [{}]",
4648                        b.shape(),
4649                        out_channels
4650                    ),
4651                });
4652            }
4653        }
4654        Ok(Self {
4655            weight: Parameter::new(weight),
4656            bias: bias.map(Parameter::new),
4657            in_channels,
4658            out_channels,
4659            kernel_size,
4660            stride,
4661            padding,
4662            output_padding,
4663            // Dense-only (groups=1, dilation=1); the group count is not
4664            // recoverable from the weight shape. Grouped/dilated go via
4665            // `new_full`. Matches the `Conv1d::from_parts` ABI policy.
4666            dilation: 1,
4667            groups: 1,
4668            training: true,
4669        })
4670    }
4671}
4672
4673impl<T: Float> Module<T> for ConvTranspose1d<T> {
4674    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
4675        // Record autocast decision for conv_transpose1d.
4676        let _autocast_cat = autocast_guard("conv_transpose1d");
4677
4678        // Unbatched input: `(C, L)` (rank 2) is accepted in addition to the
4679        // batched `(N, C, L)` (rank 3) form. Mirrors `batchify` at
4680        // `aten/src/ATen/native/Convolution.cpp:1178` (conv_transpose1d):
4681        // `unsqueeze(0)` -> transposed-conv -> `squeeze(0)`, autograd-aware.
4682        // Closes #1609.
4683        if input.ndim() == 2 {
4684            let batched = unsqueeze(input, 0)?;
4685            let output = self.forward(&batched)?;
4686            return squeeze(&output, 0);
4687        }
4688
4689        // Validate input shape: [B, C_in, L].
4690        if input.ndim() != 3 {
4691            return Err(FerrotorchError::InvalidArgument {
4692                message: format!(
4693                    "ConvTranspose1d expects 3-D input [B, C, L], got {:?}",
4694                    input.shape()
4695                ),
4696            });
4697        }
4698
4699        let batch = input.shape()[0];
4700        let c_in = input.shape()[1];
4701        let length = input.shape()[2];
4702
4703        if c_in != self.in_channels {
4704            return Err(FerrotorchError::ShapeMismatch {
4705                message: format!(
4706                    "ConvTranspose1d: expected {} input channels, got {}",
4707                    self.in_channels, c_in
4708                ),
4709            });
4710        }
4711
4712        let k = self.kernel_size;
4713        let groups = self.groups;
4714        let in_pg = self.in_channels / groups;
4715        let out_pg = self.out_channels / groups;
4716        let weight_pg_numel = in_pg * out_pg * k;
4717
4718        // Save the input device so we can restore it on the output.
4719        let input_device = input.device();
4720
4721        let input_data = input.data_vec()?;
4722        let weight_data = self.weight.data_vec()?;
4723
4724        // Per-group transposed convolution (Convolution.cpp:1723-1729). 1-D is
4725        // the 2-D group helper with `H = 1`: kernel `(1, k)`, stride `(1, s)`,
4726        // padding `(0, p)`, output_padding `(0, op)`, dilation `(1, dilation)`.
4727        // Weight slab is `[in_pg, out_pg, K]` reshaped to `[in_pg, out_pg, 1,
4728        // K]` for the helper.
4729        let zero = <T as num_traits::Zero>::zero();
4730        let mut output: Vec<T> = Vec::new();
4731        let mut l_out = 0usize;
4732
4733        for g in 0..groups {
4734            let mut group_input = vec![zero; batch * in_pg * length];
4735            for b in 0..batch {
4736                for c in 0..in_pg {
4737                    let src_c = g * in_pg + c;
4738                    let src = (b * self.in_channels + src_c) * length;
4739                    let dst = (b * in_pg + c) * length;
4740                    group_input[dst..dst + length].copy_from_slice(&input_data[src..src + length]);
4741                }
4742            }
4743
4744            let w_start = g * weight_pg_numel;
4745            let group_weight = &weight_data[w_start..w_start + weight_pg_numel];
4746
4747            let (g_out, gho, glo) = conv_transpose2d_forward_group(
4748                &group_input,
4749                batch,
4750                in_pg,
4751                out_pg,
4752                1,
4753                length,
4754                (1, k),
4755                (1, self.stride),
4756                (0, self.padding),
4757                (0, self.output_padding),
4758                (1, self.dilation),
4759                group_weight,
4760            )?;
4761            debug_assert_eq!(gho, 1);
4762            l_out = glo;
4763
4764            if output.is_empty() {
4765                output = vec![zero; batch * self.out_channels * l_out];
4766            }
4767            for b in 0..batch {
4768                for oc in 0..out_pg {
4769                    let dst_c = g * out_pg + oc;
4770                    let dst = (b * self.out_channels + dst_c) * l_out;
4771                    let src = (b * out_pg + oc) * l_out;
4772                    output[dst..dst + l_out].copy_from_slice(&g_out[src..src + l_out]);
4773                }
4774            }
4775        }
4776
4777        // Add bias if present.
4778        if let Some(ref bias) = self.bias {
4779            let bias_data = bias.data_vec()?;
4780            for b in 0..batch {
4781                for c in 0..self.out_channels {
4782                    let bval = bias_data[c];
4783                    for l in 0..l_out {
4784                        output[b * self.out_channels * l_out + c * l_out + l] += bval;
4785                    }
4786                }
4787            }
4788        }
4789
4790        let result = Tensor::from_storage(
4791            TensorStorage::cpu(output),
4792            vec![batch, self.out_channels, l_out],
4793            false,
4794        )?;
4795
4796        // Attach backward if gradients are enabled.
4797        if is_grad_enabled()
4798            && (input.requires_grad()
4799                || self.weight.requires_grad()
4800                || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
4801        {
4802            let grad_fn = Arc::new(ConvTranspose1dBackward {
4803                input: input.clone(),
4804                weight: self.weight.tensor().clone(),
4805                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
4806                in_channels: self.in_channels,
4807                out_channels: self.out_channels,
4808                kernel_size: self.kernel_size,
4809                stride: self.stride,
4810                padding: self.padding,
4811                _output_padding: self.output_padding,
4812                dilation: self.dilation,
4813                groups: self.groups,
4814                l_out,
4815            });
4816            Tensor::from_operation(
4817                TensorStorage::cpu(result.data()?.to_vec()),
4818                result.shape().to_vec(),
4819                grad_fn,
4820            )?
4821            .to(input_device) // restore device
4822        } else {
4823            result.to(input_device)
4824        }
4825    }
4826
4827    fn parameters(&self) -> Vec<&Parameter<T>> {
4828        let mut params = vec![&self.weight];
4829        if let Some(ref b) = self.bias {
4830            params.push(b);
4831        }
4832        params
4833    }
4834
4835    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
4836        let mut params = vec![&mut self.weight];
4837        if let Some(ref mut b) = self.bias {
4838            params.push(b);
4839        }
4840        params
4841    }
4842
4843    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
4844        let mut params = vec![("weight".to_string(), &self.weight)];
4845        if let Some(ref b) = self.bias {
4846            params.push(("bias".to_string(), b));
4847        }
4848        params
4849    }
4850
4851    fn train(&mut self) {
4852        self.training = true;
4853    }
4854
4855    fn eval(&mut self) {
4856        self.training = false;
4857    }
4858
4859    fn is_training(&self) -> bool {
4860        self.training
4861    }
4862}
4863
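// Illustrative sketch (not part of this module's API): a naive scatter-form
// reference for the grouped, dilated forward pass above, handy for
// cross-checking the helper-based implementation on small inputs. `Cfg` and
// `conv_transpose1d_ref` are hypothetical names; the sketch fixes `f32`,
// omits the bias, and assumes the arguments are valid (non-zero sizes,
// `groups` divides both channel counts, padding small enough that `L_out`
// does not underflow).

```rust
#[derive(Clone, Copy)]
struct Cfg {
    c_in: usize,
    c_out: usize,
    k: usize,
    stride: usize,
    padding: usize,
    output_padding: usize,
    dilation: usize,
    groups: usize,
}

/// `input` is `[batch, c_in, l]`, `weight` is `[c_in, c_out / groups, k]`.
/// Returns the `[batch, c_out, l_out]` output together with `l_out`.
fn conv_transpose1d_ref(
    input: &[f32],
    weight: &[f32],
    batch: usize,
    l: usize,
    cfg: Cfg,
) -> (Vec<f32>, usize) {
    let in_pg = cfg.c_in / cfg.groups;
    let out_pg = cfg.c_out / cfg.groups;
    assert_eq!(input.len(), batch * cfg.c_in * l);
    assert_eq!(weight.len(), cfg.c_in * out_pg * cfg.k);
    let l_out = (l - 1) * cfg.stride + cfg.dilation * (cfg.k - 1) + cfg.output_padding + 1
        - 2 * cfg.padding;
    let mut out = vec![0.0f32; batch * cfg.c_out * l_out];
    for b in 0..batch {
        for ci in 0..cfg.c_in {
            let g = ci / in_pg; // group that owns input channel `ci`
            for co in 0..out_pg {
                let oc = g * out_pg + co; // absolute output channel
                for i in 0..l {
                    let x = input[(b * cfg.c_in + ci) * l + i];
                    for t in 0..cfg.k {
                        // Each input element scatters to `i * stride + t * dilation - padding`.
                        let pos = i * cfg.stride + t * cfg.dilation;
                        if pos >= cfg.padding && pos - cfg.padding < l_out {
                            let w = weight[(ci * out_pg + co) * cfg.k + t];
                            out[(b * cfg.c_out + oc) * l_out + pos - cfg.padding] += x * w;
                        }
                    }
                }
            }
        }
    }
    (out, l_out)
}
```

// With input `[1, 2]`, weight `[1, 1]`, `stride = 2` and everything else
// dense, each input element is copied into two adjacent output slots, so the
// result is `[1, 1, 2, 2]`.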
4864// ---------------------------------------------------------------------------
4865// ConvTranspose1dBackward
4866// ---------------------------------------------------------------------------
4867
4868/// Backward function for `ConvTranspose1d` forward pass.
4869///
4870/// The gradient with respect to the input is a regular (strided, dilated) convolution of `grad_output`.
4871#[derive(Debug)]
4872struct ConvTranspose1dBackward<T: Float> {
4873    input: Tensor<T>,
4874    weight: Tensor<T>,
4875    bias: Option<Tensor<T>>,
4876    in_channels: usize,
4877    out_channels: usize,
4878    kernel_size: usize,
4879    stride: usize,
4880    padding: usize,
4881    _output_padding: usize,
4882    dilation: usize,
4883    groups: usize,
4884    l_out: usize,
4885}
4886
4887impl<T: Float> GradFn<T> for ConvTranspose1dBackward<T> {
4888    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
4889        // grad_output shape: [B, C_out, L_out]
4890        let go_data = grad_output.data_vec()?;
4891        let batch = self.input.shape()[0];
4892        let l_in = self.input.shape()[2];
4893        let k = self.kernel_size;
4894        let s = self.stride;
4895        let p = self.padding;
4896        let d = self.dilation;
4897        let groups = self.groups;
4898        let in_pg = self.in_channels / groups;
4899        let out_pg = self.out_channels / groups;
4900        let zero = <T as num_traits::Zero>::zero();
4901
4902        let weight_data_all = self.weight.data_vec()?;
4903        let input_data_all = self.input.data_vec()?;
4904
4905        // Per-group, dilation-spaced backward (1-D is the 2-D W axis).
4906        let mut gi_all = if self.input.requires_grad() {
4907            Some(vec![zero; batch * self.in_channels * l_in])
4908        } else {
4909            None
4910        };
4911        let mut gw_all = if self.weight.requires_grad() {
4912            Some(vec![zero; self.in_channels * out_pg * k])
4913        } else {
4914            None
4915        };
4916
4917        for g in 0..groups {
4918            // --- grad_input (group g) ---
4919            if let Some(gi) = gi_all.as_mut() {
4920                let col_rows = out_pg * k;
4921                let w_start = g * in_pg * out_pg * k;
4922                let weight_2d = Tensor::from_storage(
4923                    TensorStorage::cpu(
4924                        weight_data_all[w_start..w_start + in_pg * out_pg * k].to_vec(),
4925                    ),
4926                    vec![in_pg, col_rows],
4927                    false,
4928                )?;
4929
4930                let mut go_g = vec![zero; batch * out_pg * self.l_out];
4931                for b in 0..batch {
4932                    for c in 0..out_pg {
4933                        let src_c = g * out_pg + c;
4934                        let src = (b * self.out_channels + src_c) * self.l_out;
4935                        let dst = (b * out_pg + c) * self.l_out;
4936                        go_g[dst..dst + self.l_out]
4937                            .copy_from_slice(&go_data[src..src + self.l_out]);
4938                    }
4939                }
4940
4941                // im2col on grad_output [B, out_pg, 1, L_out] with kernel (1, k),
4942                // stride (1, s), padding (0, p), dilation (1, d).
4943                let (go_cols, _gcr, go_col_cols) =
4944                    im2col_dilated(&go_g, batch, out_pg, 1, self.l_out, 1, k, 1, s, 0, p, 1, d);
4945                // `output_padding >= stride` (legal when dilation > stride) yields extra columns.
4946                debug_assert!(go_col_cols >= l_in);
4947                for b in 0..batch {
4948                    let col_start = b * col_rows * go_col_cols;
4949                    let col_end = col_start + col_rows * go_col_cols;
4950                    let go_cols_b = Tensor::from_storage(
4951                        TensorStorage::cpu(go_cols[col_start..col_end].to_vec()),
4952                        vec![col_rows, go_col_cols],
4953                        false,
4954                    )?;
4955                    let gi_b = mm(&weight_2d, &go_cols_b)?;
4956                    let gi_data = gi_b.data()?;
4957                    for c in 0..in_pg {
4958                        let dst_c = g * in_pg + c;
4959                        let dst = (b * self.in_channels + dst_c) * l_in;
4960                        let src = c * go_col_cols; // row stride of `gi_b`
4961                        gi[dst..dst + l_in].copy_from_slice(&gi_data[src..src + l_in]);
4962                    }
4963                }
4964            }
4965
4966            // --- grad_weight (group g) ---
4967            if let Some(gw) = gw_all.as_mut() {
4968                for ci in 0..in_pg {
4969                    let in_c = g * in_pg + ci;
4970                    for co in 0..out_pg {
4971                        let out_c = g * out_pg + co;
4972                        for tk in 0..k {
4973                            let mut acc = zero;
4974                            for il in 0..l_in {
4975                                let ow = il * s + tk * d;
4976                                if ow >= p && (ow - p) < self.l_out {
4977                                    let go_index = out_c * self.l_out + (ow - p);
4978                                    let in_index = in_c * l_in + il;
4979                                    for b in 0..batch {
4980                                        let goi = b * self.out_channels * self.l_out + go_index;
4981                                        let ini = b * self.in_channels * l_in + in_index;
4982                                        acc += input_data_all[ini] * go_data[goi];
4983                                    }
4984                                }
4985                            }
4986                            // gw layout: [in_channels, out_pg, K].
4987                            gw[(in_c * out_pg + co) * k + tk] += acc;
4988                        }
4989                    }
4990                }
4991            }
4992        }
4993
4994        let grad_input = match gi_all {
4995            Some(gi) => Some(Tensor::from_storage(
4996                TensorStorage::cpu(gi),
4997                self.input.shape().to_vec(),
4998                false,
4999            )?),
5000            None => None,
5001        };
5002        let grad_weight = match gw_all {
5003            Some(gw) => Some(Tensor::from_storage(
5004                TensorStorage::cpu(gw),
5005                vec![self.in_channels, out_pg, k],
5006                false,
5007            )?),
5008            None => None,
5009        };
5010
5011        // --- grad_bias ---
5012        let grad_bias = match &self.bias {
5013            Some(b) if b.requires_grad() => {
5014                let zero = <T as num_traits::Zero>::zero();
5015                let mut gb = vec![zero; self.out_channels];
5016                for batch_idx in 0..batch {
5017                    for c in 0..self.out_channels {
5018                        for l in 0..self.l_out {
5019                            gb[c] += go_data
5020                                [batch_idx * self.out_channels * self.l_out + c * self.l_out + l];
5021                        }
5022                    }
5023                }
5024                Some(Tensor::from_storage(
5025                    TensorStorage::cpu(gb),
5026                    vec![self.out_channels],
5027                    false,
5028                )?)
5029            }
5030            _ => None,
5031        };
5032
5033        let mut grads = vec![grad_input, grad_weight];
5034        if self.bias.is_some() {
5035            grads.push(grad_bias);
5036        }
5037        Ok(grads)
5038    }
5039
5040    fn inputs(&self) -> Vec<&Tensor<T>> {
5041        let mut v = vec![&self.input, &self.weight];
5042        if let Some(ref b) = self.bias {
5043            v.push(b);
5044        }
5045        v
5046    }
5047
5048    fn name(&self) -> &'static str {
5049        "ConvTranspose1dBackward"
5050    }
5051}
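
// Illustrative sketch, not part of this module: the `grad_bias` loop above
// reduces `grad_output` over every axis except the channel axis. The helper
// name `sketch_bias_grad_1d` and the plain `f32` slices are assumptions made
// for the example; the real code is generic over `T: Float`.

```rust
/// Sum a `[batch, channels, length]` gradient over the batch and length
/// axes, producing one bias-gradient value per channel.
fn sketch_bias_grad_1d(go: &[f32], batch: usize, channels: usize, length: usize) -> Vec<f32> {
    assert_eq!(go.len(), batch * channels * length);
    let mut gb = vec![0.0f32; channels];
    for b in 0..batch {
        for c in 0..channels {
            // Each (batch, channel) pair owns one contiguous run of `length` values.
            let start = (b * channels + c) * length;
            gb[c] += go[start..start + length].iter().sum::<f32>();
        }
    }
    gb
}
```

// With batch = 2, channels = 2, length = 3 and values 1..=12 in row-major
// order, channel 0 collects 1+2+3+7+8+9 and channel 1 collects 4+5+6+10+11+12.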
5052
5053// ---------------------------------------------------------------------------
5054// ConvTranspose3d
5055// ---------------------------------------------------------------------------
5056
5057/// A 3-D transposed convolution (deconvolution) layer.
5058///
5059/// Applies a transposed volumetric convolution over an input `[B, C_in, D, H, W]`.
5060/// Used for upsampling in generative models and 3-D decoder networks.
5061/// Equivalent to `torch.nn.ConvTranspose3d`.
5062///
5063/// # Implementation
5064///
5065/// The forward pass inserts `(stride - 1)` zeros between each input element
5066/// along all three spatial axes (fractionally-strided convolution), then applies
5067/// a standard 3-D convolution with the kernel flipped along all spatial axes.
5068///
5069/// # Shape
5070///
5071/// - Input: `[B, in_channels, D, H, W]`
5072/// - Output: `[B, out_channels, D_out, H_out, W_out]`
5073///
5074/// where `D_out = (D - 1) * stride.0 - 2 * padding.0
5075/// + dilation.0 * (kernel_size.0 - 1) + output_padding.0 + 1` (and analogously for H and W).
5076#[derive(Debug)]
5077pub struct ConvTranspose3d<T: Float> {
5078    /// Learnable kernel weights `[in_channels, out_channels / groups, kD, kH, kW]`.
5079    ///
5080    /// Note: the channel ordering is transposed compared to `Conv3d`
5081    /// (`in_channels` first), per `torch/nn/modules/conv.py:161-167`.
5082    weight: Parameter<T>,
5083    /// Optional learnable bias `[out_channels]`.
5084    bias: Option<Parameter<T>>,
5085    /// Number of input channels.
5086    in_channels: usize,
5087    /// Number of output channels.
5088    out_channels: usize,
5089    /// Kernel spatial size `(kD, kH, kW)`.
5090    kernel_size: (usize, usize, usize),
5091    /// Stride `(sD, sH, sW)`.
5092    stride: (usize, usize, usize),
5093    /// Zero-padding `(pD, pH, pW)` removed from both sides of the output.
5094    padding: (usize, usize, usize),
5095    /// Additional size added to one side of the output `(opD, opH, opW)`.
5096    output_padding: (usize, usize, usize),
5097    /// Dilation `(dilD, dilH, dilW)`. `(1, 1, 1)` is the dense default
5098    /// (`dilation` arg of `F.conv_transpose3d`).
5099    dilation: (usize, usize, usize),
5100    /// Number of blocked input/output channel groups. `1` is dense. Must divide
5101    /// both `in_channels` and `out_channels`. Transposed weight
5102    /// `[in_channels, out_channels / groups, kD, kH, kW]`; per-group partition
5103    /// mirrors `aten/src/ATen/native/Convolution.cpp:1723-1729`.
5104    groups: usize,
5105    /// Whether the module is in training mode.
5106    training: bool,
5107}
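
// Illustrative sketch, not part of this module: the per-axis output extent of
// a transposed convolution, using the formula PyTorch documents for
// `ConvTranspose3d` (including the `dilation` term). The function name
// `sketch_transposed_out_extent` is hypothetical, and returning `None` for
// degenerate arguments is a choice made only for this example.

```rust
/// Output extent along one spatial axis:
/// `(input - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1`.
/// Returns `None` when an argument is zero or the result would not be positive.
fn sketch_transposed_out_extent(
    input: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    output_padding: usize,
    dilation: usize,
) -> Option<usize> {
    if input == 0 || kernel == 0 || stride == 0 || dilation == 0 {
        return None;
    }
    // Add all non-negative terms first so the subtraction is the only step
    // that can fail in unsigned arithmetic.
    let grown = (input - 1) * stride + dilation * (kernel - 1) + output_padding + 1;
    grown.checked_sub(2 * padding).filter(|&n| n > 0)
}
```

// For example, kernel 2 with stride 2 and no padding doubles the extent, and
// an over-padded single-element input yields `None` instead of wrapping.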
5108
5109impl<T: Float> ConvTranspose3d<T> {
5110    /// Create a new `ConvTranspose3d` layer (dense, dilation `(1, 1, 1)`,
5111    /// `groups = 1`).
5112    ///
5113    /// Weight is initialized with Kaiming uniform (ReLU gain).
5114    /// Bias, if enabled, is initialized U(-bound, bound) with
5115    /// `bound = 1/sqrt(fan_in)` per `torch/nn/modules/conv.py:198-201`.
5116    ///
5117    /// Thin shim over [`ConvTranspose3d::new_full`].
5118    pub fn new(
5119        in_channels: usize,
5120        out_channels: usize,
5121        kernel_size: (usize, usize, usize),
5122        stride: (usize, usize, usize),
5123        padding: (usize, usize, usize),
5124        output_padding: (usize, usize, usize),
5125        bias: bool,
5126    ) -> FerrotorchResult<Self> {
5127        Self::new_full(
5128            in_channels,
5129            out_channels,
5130            kernel_size,
5131            stride,
5132            padding,
5133            output_padding,
5134            (1, 1, 1),
5135            1,
5136            bias,
5137        )
5138    }
5139
5140    /// Create a new `ConvTranspose3d` with the full PyTorch-shaped argument set,
5141    /// including `dilation` and `groups`.
5142    ///
5143    /// Mirrors `torch.nn.ConvTranspose3d.__init__`
5144    /// (`torch/nn/modules/conv.py:1325-1356`). `groups` must divide BOTH
5145    /// `in_channels` and `out_channels` (`conv.py:105-110`). Transposed weight
5146    /// layout `[in_channels, out_channels / groups, kD, kH, kW]` (`conv.py:161-167`).
5147    #[allow(clippy::too_many_arguments)]
5148    pub fn new_full(
5149        in_channels: usize,
5150        out_channels: usize,
5151        kernel_size: (usize, usize, usize),
5152        stride: (usize, usize, usize),
5153        padding: (usize, usize, usize),
5154        output_padding: (usize, usize, usize),
5155        dilation: (usize, usize, usize),
5156        groups: usize,
5157        bias: bool,
5158    ) -> FerrotorchResult<Self> {
5159        if in_channels == 0 || out_channels == 0 {
5160            return Err(FerrotorchError::InvalidArgument {
5161                message: "in_channels and out_channels must be > 0".into(),
5162            });
5163        }
5164        if kernel_size.0 == 0 || kernel_size.1 == 0 || kernel_size.2 == 0 {
5165            return Err(FerrotorchError::InvalidArgument {
5166                message: "kernel_size must be > 0 in all dimensions".into(),
5167            });
5168        }
5169        if stride.0 == 0 || stride.1 == 0 || stride.2 == 0 {
5170            return Err(FerrotorchError::InvalidArgument {
5171                message: "stride must be > 0 in all dimensions".into(),
5172            });
5173        }
5174        if dilation.0 == 0 || dilation.1 == 0 || dilation.2 == 0 {
5175            return Err(FerrotorchError::InvalidArgument {
5176                message: "dilation must be > 0 in all dimensions".into(),
5177            });
5178        }
5179        if groups == 0 {
5180            return Err(FerrotorchError::InvalidArgument {
5181                message: "groups must be a positive integer".into(),
5182            });
5183        }
5184        if in_channels % groups != 0 {
5185            return Err(FerrotorchError::InvalidArgument {
5186                message: format!(
5187                    "in_channels ({in_channels}) must be divisible by groups ({groups})"
5188                ),
5189            });
5190        }
5191        if out_channels % groups != 0 {
5192            return Err(FerrotorchError::InvalidArgument {
5193                message: format!(
5194                    "out_channels ({out_channels}) must be divisible by groups ({groups})"
5195                ),
5196            });
5197        }
5198        if output_padding.0 >= stride.0.max(dilation.0)
5199            || output_padding.1 >= stride.1.max(dilation.1)
5200            || output_padding.2 >= stride.2.max(dilation.2)
5201        {
5202            return Err(FerrotorchError::InvalidArgument {
5203                message:
5204                    "output_padding must be strictly less than max(stride, dilation) in all dimensions"
5205                        .into(),
5206            });
5207        }
5208
5209        // Weight shape: [in_channels, out_channels / groups, kD, kH, kW] (transposed layout).
5210        let (kd, kh, kw) = kernel_size;
5211        let out_per_group = out_channels / groups;
5212        let mut weight = Parameter::zeros(&[in_channels, out_per_group, kd, kh, kw])?;
5213        kaiming_uniform(&mut weight, NonLinearity::ReLU)?;
5214
5215        let bias_param = if bias {
5216            let mut b = Parameter::zeros(&[out_channels])?;
5217            // `torch/nn/modules/conv.py:198-201`: bias U(-bound, bound) with
5218            //   `bound = 1 / sqrt(fan_in)`, fan_in = (out/groups) * kD*kH*kW.
5219            let fan_in = out_per_group * kd * kh * kw;
5220            let bound = if fan_in > 0 {
5221                1.0 / (fan_in as f64).sqrt()
5222            } else {
5223                0.0
5224            };
5225            uniform_init(&mut b, -bound, bound)?;
5226            Some(b)
5227        } else {
5228            None
5229        };
5230
5231        Ok(Self {
5232            weight,
5233            bias: bias_param,
5234            in_channels,
5235            out_channels,
5236            kernel_size,
5237            stride,
5238            padding,
5239            output_padding,
5240            dilation,
5241            groups,
5242            training: true,
5243        })
5244    }
5245
5246    /// Configure the boundary handling for the spatial padding.
5247    ///
5248    /// Only [`crate::padding::PaddingMode::Zeros`] is accepted: upstream
5249    /// `_ConvTransposeNd.__init__` raises
5250    /// `ValueError('Only "zeros" padding mode is supported for ConvTranspose3d')`
5251    /// for any non-`zeros` mode (`torch/nn/modules/conv.py:755-758`). This
5252    /// matches that behaviour by returning an error rather than silently
5253    /// accepting the unsupported mode (R-DEV-2). The returned layer is
5254    /// unchanged (the only valid mode is `Zeros`, the constructed default).
5255    /// Closes #1443.
5256    pub fn with_padding_mode(self, mode: crate::padding::PaddingMode) -> FerrotorchResult<Self> {
5257        reject_non_zeros_transpose(mode, "ConvTranspose3d")?;
5258        Ok(self)
5259    }
5260
5261    /// The number of learnable scalar parameters.
5262    pub fn num_parameters(&self) -> usize {
5263        let w = self.in_channels
5264            * (self.out_channels / self.groups)
5265            * self.kernel_size.0
5266            * self.kernel_size.1
5267            * self.kernel_size.2;
5268        let b = if self.bias.is_some() {
5269            self.out_channels
5270        } else {
5271            0
5272        };
5273        w + b
5274    }
5275
5276    /// Build a `ConvTranspose3d` from caller-supplied weight and optional bias.
5277    ///
5278    /// `weight` must have shape `[in_channels, out_channels, kD, kH, kW]`
5279    /// (transposed channel ordering vs `Conv3d`). The resulting layer is dense-only
5280    /// (`groups = 1`, `dilation = (1, 1, 1)`). Used by `nn::functional::conv_transpose3d`.
5281    pub fn from_parts(
5282        weight: Tensor<T>,
5283        bias: Option<Tensor<T>>,
5284        stride: (usize, usize, usize),
5285        padding: (usize, usize, usize),
5286        output_padding: (usize, usize, usize),
5287    ) -> FerrotorchResult<Self> {
5288        if weight.ndim() != 5 {
5289            return Err(FerrotorchError::ShapeMismatch {
5290                message: format!(
5291                    "ConvTranspose3d::from_parts: weight must be 5-D [in, out, kD, kH, kW], got {:?}",
5292                    weight.shape()
5293                ),
5294            });
5295        }
5296        let in_channels = weight.shape()[0];
5297        let out_channels = weight.shape()[1];
5298        let kernel_size = (weight.shape()[2], weight.shape()[3], weight.shape()[4]);
5299        if output_padding.0 >= stride.0
5300            || output_padding.1 >= stride.1
5301            || output_padding.2 >= stride.2
5302        {
5303            return Err(FerrotorchError::InvalidArgument {
5304                message: "output_padding must be strictly less than stride in all dimensions"
5305                    .into(),
5306            });
5307        }
5308        if let Some(b) = &bias {
5309            if b.ndim() != 1 || b.shape()[0] != out_channels {
5310                return Err(FerrotorchError::ShapeMismatch {
5311                    message: format!(
5312                        "ConvTranspose3d::from_parts: bias shape {:?} != [{}]",
5313                        b.shape(),
5314                        out_channels
5315                    ),
5316                });
5317            }
5318        }
5319        Ok(Self {
5320            weight: Parameter::new(weight),
5321            bias: bias.map(Parameter::new),
5322            in_channels,
5323            out_channels,
5324            kernel_size,
5325            stride,
5326            padding,
5327            output_padding,
5328            // Dense-only (groups=1, dilation=1) ABI policy, mirroring
5329            // `Conv3d::from_parts`. Grouped/dilated go via `new_full`.
5330            dilation: (1, 1, 1),
5331            groups: 1,
5332            training: true,
5333        })
5334    }
5335}
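
// Illustrative sketch, not part of this module: counting learnable scalars
// from the transposed weight layout `[in_channels, out_channels / groups, kD,
// kH, kW]` plus an optional `[out_channels]` bias. The helper name
// `sketch_param_count` is hypothetical; it returns `None` for the same
// `groups` conditions that `new_full` rejects.

```rust
/// Number of weight + bias scalars for a grouped transposed 3-D convolution.
fn sketch_param_count(
    in_channels: usize,
    out_channels: usize,
    groups: usize,
    kernel: (usize, usize, usize),
    bias: bool,
) -> Option<usize> {
    if groups == 0 || in_channels % groups != 0 || out_channels % groups != 0 {
        return None;
    }
    let (kd, kh, kw) = kernel;
    // Dim 0 keeps every input channel; dim 1 only spans one group's outputs.
    let weight = in_channels * (out_channels / groups) * kd * kh * kw;
    let bias_len = if bias { out_channels } else { 0 };
    Some(weight + bias_len)
}
```

// With 8 input channels, 4 output channels and a 3x3x3 kernel, `groups = 2`
// halves the weight count relative to the dense layer.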
5336
5337/// Insert `(stride - 1)` zeros between each element along three spatial axes.
5338///
5339/// Given input `[B, C, D, H, W]`, produces `[B, C, D_up, H_up, W_up]` where
5340/// `D_up = (D - 1) * stride_d + 1` (and analogously for H, W).
5341// Internal kernel for ConvTranspose3d forward (step 1 of
5342// `conv_transpose3d_forward_group`): arguments are the 3-D shape descriptor +
5343// per-axis stride; refactoring to a config struct would add allocation in a hot path.
5344#[allow(clippy::too_many_arguments)]
5345fn stride_insert_zeros_3d<T: Float>(
5346    input: &[T],
5347    batch: usize,
5348    channels: usize,
5349    d: usize,
5350    h: usize,
5351    w: usize,
5352    stride_d: usize,
5353    stride_h: usize,
5354    stride_w: usize,
5355) -> (Vec<T>, usize, usize, usize) {
5356    let d_up = (d - 1) * stride_d + 1;
5357    let h_up = (h - 1) * stride_h + 1;
5358    let w_up = (w - 1) * stride_w + 1;
5359    let zero = <T as num_traits::Zero>::zero();
5360    let mut out = vec![zero; batch * channels * d_up * h_up * w_up];
5361
5362    for b in 0..batch {
5363        for c in 0..channels {
5364            for id in 0..d {
5365                for ih in 0..h {
5366                    for iw in 0..w {
5367                        let od = id * stride_d;
5368                        let oh = ih * stride_h;
5369                        let ow = iw * stride_w;
5370                        out[b * channels * d_up * h_up * w_up
5371                            + c * d_up * h_up * w_up
5372                            + od * h_up * w_up
5373                            + oh * w_up
5374                            + ow] = input
5375                            [b * channels * d * h * w + c * d * h * w + id * h * w + ih * w + iw];
5376                    }
5377                }
5378            }
5379        }
5380    }
5381
5382    (out, d_up, h_up, w_up)
5383}
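
// Illustrative sketch, not part of this module: the same zero-insertion on a
// single 1-D signal, which is easier to check by hand than the 5-D case. The
// name `sketch_insert_zeros_1d` is hypothetical, and returning an empty
// vector for empty input or zero stride is a choice made for the example.

```rust
/// Place `stride - 1` zeros between consecutive elements of `input`,
/// giving `(len - 1) * stride + 1` elements.
fn sketch_insert_zeros_1d(input: &[f32], stride: usize) -> Vec<f32> {
    if input.is_empty() || stride == 0 {
        return Vec::new();
    }
    let mut out = vec![0.0f32; (input.len() - 1) * stride + 1];
    for (i, &v) in input.iter().enumerate() {
        // Element `i` lands at `i * stride`; the gaps keep their zeros.
        out[i * stride] = v;
    }
    out
}
```

// `[1, 2, 3]` with stride 2 becomes `[1, 0, 2, 0, 3]`; stride 1 is the identity.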
5384
5385/// Crop a `[batch, channels, D, H, W]` volume by `crop_*` elements off BOTH
5386/// ends of each spatial axis, returning the cropped buffer plus its new
5387/// spatial extents. Used by the transposed-conv forward when the internal
5388/// padding `dilation*(k-1) - padding` is negative (i.e. `padding >
5389/// dilation*(k-1)`): a negative internal pad means the upsampled signal must
5390/// be trimmed rather than zero-padded before the stride-1 internal
5391/// convolution, mirroring upstream's output-extent-bounded `col2vol` scatter
5392/// (`aten/src/ATen/native/vol2col.h:80-106`). Callers must guarantee
5393/// `2*crop_* < extent`: otherwise the transposed output extent is
5394/// non-positive and the `usize` subtractions below underflow.
5395// Internal kernel: the descriptor mirrors a 3-D volume layout; a config struct
5396// would force allocation in the per-group hot loop.
5397#[allow(clippy::too_many_arguments)]
5398fn crop_volume_3d<T: Float>(
5399    input: &[T],
5400    batch: usize,
5401    channels: usize,
5402    d: usize,
5403    h: usize,
5404    w: usize,
5405    crop_d: usize,
5406    crop_h: usize,
5407    crop_w: usize,
5408) -> (Vec<T>, usize, usize, usize) {
5409    let d_out = d - 2 * crop_d;
5410    let h_out = h - 2 * crop_h;
5411    let w_out = w - 2 * crop_w;
5412    let zero = <T as num_traits::Zero>::zero();
5413    let mut out = vec![zero; batch * channels * d_out * h_out * w_out];
5414
5415    for b in 0..batch {
5416        for c in 0..channels {
5417            for od in 0..d_out {
5418                for oh in 0..h_out {
5419                    let src =
5420                        (((b * channels + c) * d + (od + crop_d)) * h + (oh + crop_h)) * w + crop_w;
5421                    let dst = (((b * channels + c) * d_out + od) * h_out + oh) * w_out;
5422                    out[dst..dst + w_out].copy_from_slice(&input[src..src + w_out]);
5423                }
5424            }
5425        }
5426    }
5427
5428    (out, d_out, h_out, w_out)
5429}
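
// Illustrative sketch, not part of this module: splitting the signed internal
// pad `dilation*(k-1) - padding` into a non-negative crop and a non-negative
// zero-pad without leaving unsigned arithmetic. The name
// `sketch_split_internal_pad` is hypothetical; the forward helper in this
// file does the equivalent split through `isize` casts.

```rust
/// Returns `(crop, pad)` for one spatial axis. At most one of the two is
/// non-zero: `crop` when `padding` exceeds the kernel reach, `pad` otherwise.
fn sketch_split_internal_pad(kernel: usize, dilation: usize, padding: usize) -> (usize, usize) {
    // Reach of the dilated kernel beyond its first tap, i.e. `eff_k - 1`.
    let reach = dilation * kernel.saturating_sub(1);
    if padding > reach {
        (padding - reach, 0)
    } else {
        (0, reach - padding)
    }
}
```

// A 1-wide kernel with padding 1 must crop one element per side, while a
// 3-wide kernel with padding 1 still zero-pads by one.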
5430
5431/// Flip a 3-D kernel along all spatial axes and transpose channel dimensions:
5432/// `kernel[c_in, c_out, kD, kH, kW]` ->
5433/// `kernel[c_out, c_in, kD-1-kd, kH-1-kh, kW-1-kw]`.
5434fn flip_kernel_3d<T: Float>(
5435    kernel: &[T],
5436    c_in: usize,
5437    c_out: usize,
5438    kd: usize,
5439    kh: usize,
5440    kw: usize,
5441) -> Vec<T> {
5442    let zero = <T as num_traits::Zero>::zero();
5443    let mut flipped = vec![zero; c_out * c_in * kd * kh * kw];
5444
5445    for ci in 0..c_in {
5446        for co in 0..c_out {
5447            for dd in 0..kd {
5448                for dh in 0..kh {
5449                    for dw in 0..kw {
5450                        // Source: [c_in, c_out, dd, dh, dw]
5451                        let src = ci * c_out * kd * kh * kw
5452                            + co * kd * kh * kw
5453                            + dd * kh * kw
5454                            + dh * kw
5455                            + dw;
5456                        // Dest: [c_out, c_in, kD-1-dd, kH-1-dh, kW-1-dw]
5457                        let dst = co * c_in * kd * kh * kw
5458                            + ci * kd * kh * kw
5459                            + (kd - 1 - dd) * kh * kw
5460                            + (kh - 1 - dh) * kw
5461                            + (kw - 1 - dw);
5462                        flipped[dst] = kernel[src];
5463                    }
5464                }
5465            }
5466        }
5467    }
5468
5469    flipped
5470}
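
// Illustrative sketch, not part of this module: the 1-D analogue of the flip
// above, small enough to verify by hand. The name `sketch_flip_kernel_1d` is
// hypothetical and the example fixes the element type to `f32`.

```rust
/// Map `kernel[c_in, c_out, k]` to `flipped[c_out, c_in, k]` while reversing
/// the spatial axis, so that `flipped[co][ci][k-1-t] == kernel[ci][co][t]`.
fn sketch_flip_kernel_1d(kernel: &[f32], c_in: usize, c_out: usize, k: usize) -> Vec<f32> {
    assert_eq!(kernel.len(), c_in * c_out * k);
    let mut flipped = vec![0.0f32; c_out * c_in * k];
    for ci in 0..c_in {
        for co in 0..c_out {
            for t in 0..k {
                let src = (ci * c_out + co) * k + t;
                let dst = (co * c_in + ci) * k + (k - 1 - t);
                flipped[dst] = kernel[src];
            }
        }
    }
    flipped
}
```

// With `k = 1` only the channel transpose is visible; with one input channel
// only the spatial reversal is.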
5471
5472/// Single-group transposed 3-D convolution forward (the `groups == 1` core).
5473///
5474/// Operates on a channel-sliced input slab `[B, in_pg, D, H, W]` and a weight
5475/// slab `[in_pg, out_pg, kD, kH, kW]` (the transposed grouped layout,
5476/// `torch/nn/modules/conv.py:164`), returning `[B, out_pg, d_out, h_out,
5477/// w_out]`. Generalises the dense (#1560) algorithm for `dilation` via
5478/// `im2col_3d_dilated` (`internal_pad = dilation*(k-1) - padding`).
5479// Internal kernel: the argument set mirrors the 3-D transposed-conv descriptor;
5480// a config struct would force allocation in the per-group hot loop.
5481#[allow(clippy::too_many_arguments)]
5482fn conv_transpose3d_forward_group<T: Float>(
5483    input_data: &[T],
5484    batch: usize,
5485    in_pg: usize,
5486    out_pg: usize,
5487    d: usize,
5488    h: usize,
5489    w: usize,
5490    kernel_size: (usize, usize, usize),
5491    stride: (usize, usize, usize),
5492    padding: (usize, usize, usize),
5493    output_padding: (usize, usize, usize),
5494    dilation: (usize, usize, usize),
5495    group_weight: &[T],
5496) -> FerrotorchResult<(Vec<T>, usize, usize, usize)> {
5497    let (kd, kh, kw) = kernel_size;
5498    let (sd, sh, sw) = stride;
5499    let (pd, ph, pw) = padding;
5500    let (opd, oph, opw) = output_padding;
5501    let (dd, dh, dw) = dilation;
5502
5503    // Step 1: stride-insert zeros, then append the `output_padding` boundary.
5504    let (upsampled, d_up_core, h_up_core, w_up_core) =
5505        stride_insert_zeros_3d(input_data, batch, in_pg, d, h, w, sd, sh, sw);
5506    let d_up = d_up_core + opd;
5507    let h_up = h_up_core + oph;
5508    let w_up = w_up_core + opw;
5509    let upsampled = if opd > 0 || oph > 0 || opw > 0 {
5510        let zero = <T as num_traits::Zero>::zero();
5511        let mut ext = vec![zero; batch * in_pg * d_up * h_up * w_up];
5512        for b in 0..batch {
5513            for c in 0..in_pg {
5514                for id in 0..d_up_core {
5515                    for ih in 0..h_up_core {
5516                        let src = (((b * in_pg + c) * d_up_core + id) * h_up_core + ih) * w_up_core;
5517                        let dst = (((b * in_pg + c) * d_up + id) * h_up + ih) * w_up;
5518                        ext[dst..dst + w_up_core].copy_from_slice(&upsampled[src..src + w_up_core]);
5519                    }
5520                }
5521            }
5522        }
5523        ext
5524    } else {
5525        upsampled
5526    };
5527
5528    // Step 2: flip the kernel and transpose channel dimensions.
5529    let flipped = flip_kernel_3d(group_weight, in_pg, out_pg, kd, kh, kw);
5530
5531    // Step 3: dilation-spaced stride-1 internal convolution. The internal pad
5532    // is `internal_pad = dilation*(k-1) - padding = eff_k - 1 - padding`,
5533    // `eff_k = dilation*(k-1)+1`. When `padding > dilation*(k-1)` this goes
5534    // NEGATIVE — the transposed-conv output position maps to a read index
5535    // INSIDE the upsampled buffer rather than into a zero-pad halo, so the
5536    // signal must be CROPPED by `|internal_pad|` on each side instead of
5537    // zero-padded (a negative `usize` here would wrap and silently drop the
5538    // whole scatter — the #1619 `output_padding=1`+`dilation=(2,3,2)`,`kw=1`,
5539    // `pw=1` case). This matches upstream's `col2vol` scatter
5540    // (`aten/src/ATen/native/vol2col.h:80-106`), whose `t_pad = t*stride - pad
5541    // + t_offset*dilation` mapping is bounded only by the OUTPUT extent and
5542    // naturally drops positions that fall outside it — there is no separate
5543    // non-negative "internal pad" in upstream.
5544    let eff_kd = dd * (kd - 1) + 1;
5545    let eff_kh = dh * (kh - 1) + 1;
5546    let eff_kw = dw * (kw - 1) + 1;
5547    let signed_pad_d = (eff_kd - 1) as isize - pd as isize;
5548    let signed_pad_h = (eff_kh - 1) as isize - ph as isize;
5549    let signed_pad_w = (eff_kw - 1) as isize - pw as isize;
5550    // Crop the negative dims; zero-pad the non-negative dims (the latter is
5551    // handled by `im2col_3d_dilated`'s pad argument).
5552    let crop_d = (-signed_pad_d).max(0) as usize;
5553    let crop_h = (-signed_pad_h).max(0) as usize;
5554    let crop_w = (-signed_pad_w).max(0) as usize;
5555    let (conv_input, d_in, h_in, w_in) = if crop_d > 0 || crop_h > 0 || crop_w > 0 {
5556        crop_volume_3d(
5557            &upsampled, batch, in_pg, d_up, h_up, w_up, crop_d, crop_h, crop_w,
5558        )
5559    } else {
5560        (upsampled, d_up, h_up, w_up)
5561    };
5562    let internal_pad_d = signed_pad_d.max(0) as usize;
5563    let internal_pad_h = signed_pad_h.max(0) as usize;
5564    let internal_pad_w = signed_pad_w.max(0) as usize;
5565
5566    let (cols, col_rows, col_cols) = im2col_3d_dilated(
5567        &conv_input,
5568        batch,
5569        in_pg,
5570        d_in,
5571        h_in,
5572        w_in,
5573        kd,
5574        kh,
5575        kw,
5576        1,
5577        1,
5578        1,
5579        internal_pad_d,
5580        internal_pad_h,
5581        internal_pad_w,
5582        dd,
5583        dh,
5584        dw,
5585    );
5586
5587    let d_out = (d_in + 2 * internal_pad_d - eff_kd) + 1;
5588    let h_out = (h_in + 2 * internal_pad_h - eff_kh) + 1;
5589    let w_out = (w_in + 2 * internal_pad_w - eff_kw) + 1;
5590    let spatial_out = d_out * h_out * w_out;
5591
5592    let flipped_2d =
5593        Tensor::from_storage(TensorStorage::cpu(flipped), vec![out_pg, col_rows], false)?;
5594
5595    let zero = <T as num_traits::Zero>::zero();
5596    let mut output = vec![zero; batch * out_pg * spatial_out];
5597
5598    for b in 0..batch {
5599        let col_start = b * col_rows * col_cols;
5600        let col_end = col_start + col_rows * col_cols;
5601        let cols_b = Tensor::from_storage(
5602            TensorStorage::cpu(cols[col_start..col_end].to_vec()),
5603            vec![col_rows, col_cols],
5604            false,
5605        )?;
5606        let out_b = mm(&flipped_2d, &cols_b)?;
5607        let out_data = out_b.data()?;
5608        let out_start = b * out_pg * spatial_out;
5609        let copy_len = out_pg * spatial_out;
5610        output[out_start..out_start + copy_len].copy_from_slice(&out_data[..copy_len]);
5611    }
5612
5613    Ok((output, d_out, h_out, w_out))
5614}
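
// Illustrative sketch, not part of this module: a single-channel 1-D check
// that the "zero-insert, pad, convolve with the flipped kernel" route used
// above agrees with the direct scatter definition of a transposed
// convolution. All names are hypothetical. The sketch assumes dilation 1 and
// `padding <= k - 1` (so no crop is needed); it skips im2col and uses a plain
// sliding dot product instead.

```rust
/// Output length for dilation 1; panics on arguments the sketch does not cover.
fn sketch_out_len(n: usize, k: usize, stride: usize, padding: usize, output_padding: usize) -> usize {
    assert!(n >= 1 && k >= 1 && stride >= 1 && padding <= k - 1);
    ((n - 1) * stride + k + output_padding)
        .checked_sub(2 * padding)
        .filter(|&len| len > 0)
        .expect("non-positive output length")
}

/// Direct definition: input `i` and tap `t` add `x[i] * w[t]` at
/// `i * stride + t - padding`, if that position is inside the output.
fn sketch_conv_transpose_1d_scatter(
    x: &[f32], w: &[f32], stride: usize, padding: usize, output_padding: usize,
) -> Vec<f32> {
    let len = sketch_out_len(x.len(), w.len(), stride, padding, output_padding);
    let mut out = vec![0.0f32; len];
    for (i, &xv) in x.iter().enumerate() {
        for (t, &wv) in w.iter().enumerate() {
            let pos = i * stride + t;
            if pos >= padding && pos - padding < len {
                out[pos - padding] += xv * wv;
            }
        }
    }
    out
}

/// Upsample route: zero-insert, append `output_padding` zeros at the end,
/// pad both sides by `k - 1 - padding`, then slide the reversed kernel.
fn sketch_conv_transpose_1d_upsample(
    x: &[f32], w: &[f32], stride: usize, padding: usize, output_padding: usize,
) -> Vec<f32> {
    let k = w.len();
    let out_len = sketch_out_len(x.len(), k, stride, padding, output_padding);
    let pad = k - 1 - padding;
    let up_len = (x.len() - 1) * stride + 1 + output_padding;
    let mut padded = vec![0.0f32; up_len + 2 * pad];
    for (i, &xv) in x.iter().enumerate() {
        padded[pad + i * stride] = xv;
    }
    let flipped: Vec<f32> = w.iter().rev().copied().collect();
    (0..out_len)
        .map(|o| (0..k).map(|t| padded[o + t] * flipped[t]).sum::<f32>())
        .collect()
}
```

// For `x = [1, 2, 3]`, `w = [1, 10, 100]`, stride 2, padding 1 and
// output_padding 1, both routes give `[10, 102, 20, 203, 30, 300]`.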
5615
5616impl<T: Float> Module<T> for ConvTranspose3d<T> {
5617    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
5618        // Record autocast decision for conv_transpose3d.
5619        let _autocast_cat = autocast_guard("conv_transpose3d");
5620
5621        // Unbatched input: `(C, D, H, W)` (rank 4) is accepted in addition to
5622        // the batched `(N, C, D, H, W)` (rank 5) form. Mirrors `batchify` at
5623        // `aten/src/ATen/native/Convolution.cpp:1216` (conv_transpose3d):
5624        // `unsqueeze(0)` -> transposed-conv -> `squeeze(0)`, autograd-aware.
5625        // Closes #1609.
5626        if input.ndim() == 4 {
5627            let batched = unsqueeze(input, 0)?;
5628            let output = self.forward(&batched)?;
5629            return squeeze(&output, 0);
5630        }
5631
5632        // Validate input shape: [B, C_in, D, H, W].
5633        if input.ndim() != 5 {
5634            return Err(FerrotorchError::InvalidArgument {
5635                message: format!(
5636                    "ConvTranspose3d expects 5-D input [B, C, D, H, W], got {:?}",
5637                    input.shape()
5638                ),
5639            });
5640        }
5641
5642        let batch = input.shape()[0];
5643        let c_in = input.shape()[1];
5644        let d = input.shape()[2];
5645        let h = input.shape()[3];
5646        let w = input.shape()[4];
5647
5648        if c_in != self.in_channels {
5649            return Err(FerrotorchError::ShapeMismatch {
5650                message: format!(
5651                    "ConvTranspose3d: expected {} input channels, got {}",
5652                    self.in_channels, c_in
5653                ),
5654            });
5655        }
5656
5657        let (kd, kh, kw) = self.kernel_size;
5658        let groups = self.groups;
5659        let in_pg = self.in_channels / groups;
5660        let out_pg = self.out_channels / groups;
5661        let weight_pg_numel = in_pg * out_pg * kd * kh * kw;
5662
5663        // Save the input device so we can restore it on the output.
5664        let input_device = input.device();
5665
5666        let input_data = input.data_vec()?;
5667        let weight_data = self.weight.data_vec()?;
5668
5669        // Per-group transposed convolution (Convolution.cpp:1723-1729).
5670        let zero = <T as num_traits::Zero>::zero();
5671        let mut output: Vec<T> = Vec::new();
5672        let (mut d_out, mut h_out, mut w_out) = (0usize, 0usize, 0usize);
5673        let spatial_in = d * h * w;
5674
5675        for g in 0..groups {
5676            let mut group_input = vec![zero; batch * in_pg * spatial_in];
5677            for b in 0..batch {
5678                for c in 0..in_pg {
5679                    let src_c = g * in_pg + c;
5680                    let src = (b * self.in_channels + src_c) * spatial_in;
5681                    let dst = (b * in_pg + c) * spatial_in;
5682                    group_input[dst..dst + spatial_in]
5683                        .copy_from_slice(&input_data[src..src + spatial_in]);
5684                }
5685            }
5686
5687            let w_start = g * weight_pg_numel;
5688            let group_weight = &weight_data[w_start..w_start + weight_pg_numel];
5689
5690            let (g_out, gdo, gho, gwo) = conv_transpose3d_forward_group(
5691                &group_input,
5692                batch,
5693                in_pg,
5694                out_pg,
5695                d,
5696                h,
5697                w,
5698                self.kernel_size,
5699                self.stride,
5700                self.padding,
5701                self.output_padding,
5702                self.dilation,
5703                group_weight,
5704            )?;
5705            d_out = gdo;
5706            h_out = gho;
5707            w_out = gwo;
5708            let spatial_out = d_out * h_out * w_out;
5709
5710            if output.is_empty() {
5711                output = vec![zero; batch * self.out_channels * spatial_out];
5712            }
5713            for b in 0..batch {
5714                for oc in 0..out_pg {
5715                    let dst_c = g * out_pg + oc;
5716                    let dst = (b * self.out_channels + dst_c) * spatial_out;
5717                    let src = (b * out_pg + oc) * spatial_out;
5718                    output[dst..dst + spatial_out].copy_from_slice(&g_out[src..src + spatial_out]);
5719                }
5720            }
5721        }
5722
5723        let spatial_out = d_out * h_out * w_out;
5724
5725        // Add bias if present.
5726        if let Some(ref bias) = self.bias {
5727            let bias_data = bias.data_vec()?;
5728            for b in 0..batch {
5729                for c in 0..self.out_channels {
5730                    let bval = bias_data[c];
5731                    for s in 0..spatial_out {
5732                        output[b * self.out_channels * spatial_out + c * spatial_out + s] += bval;
5733                    }
5734                }
5735            }
5736        }
5737
5738        let result = Tensor::from_storage(
5739            TensorStorage::cpu(output),
5740            vec![batch, self.out_channels, d_out, h_out, w_out],
5741            false,
5742        )?;
5743
5744        // Attach backward if gradients are enabled.
5745        if is_grad_enabled()
5746            && (input.requires_grad()
5747                || self.weight.requires_grad()
5748                || self.bias.as_ref().is_some_and(|b| b.requires_grad()))
5749        {
5750            let grad_fn = Arc::new(ConvTranspose3dBackward {
5751                input: input.clone(),
5752                weight: self.weight.tensor().clone(),
5753                bias: self.bias.as_ref().map(|b| b.tensor().clone()),
5754                in_channels: self.in_channels,
5755                out_channels: self.out_channels,
5756                kernel_size: self.kernel_size,
5757                stride: self.stride,
5758                padding: self.padding,
5759                _output_padding: self.output_padding,
5760                dilation: self.dilation,
5761                groups: self.groups,
5762                d_out,
5763                h_out,
5764                w_out,
5765            });
5766            Tensor::from_operation(
5767                TensorStorage::cpu(result.data()?.to_vec()),
5768                result.shape().to_vec(),
5769                grad_fn,
5770            )?
5771            .to(input_device) // restore device
5772        } else {
5773            result.to(input_device)
5774        }
5775    }
5776
5777    fn parameters(&self) -> Vec<&Parameter<T>> {
5778        let mut params = vec![&self.weight];
5779        if let Some(ref b) = self.bias {
5780            params.push(b);
5781        }
5782        params
5783    }
5784
5785    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
5786        let mut params = vec![&mut self.weight];
5787        if let Some(ref mut b) = self.bias {
5788            params.push(b);
5789        }
5790        params
5791    }
5792
5793    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
5794        let mut params = vec![("weight".to_string(), &self.weight)];
5795        if let Some(ref b) = self.bias {
5796            params.push(("bias".to_string(), b));
5797        }
5798        params
5799    }
5800
5801    fn train(&mut self) {
5802        self.training = true;
5803    }
5804
5805    fn eval(&mut self) {
5806        self.training = false;
5807    }
5808
5809    fn is_training(&self) -> bool {
5810        self.training
5811    }
5812}
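
// Illustrative sketch, not part of this module: extracting one group's
// channel slab from a `[batch, channels, spatial]` buffer, as the per-group
// loop in `forward` does before calling the single-group helper. The name
// `sketch_group_slab` is hypothetical. Because a group's channels are
// contiguous within each batch entry, the sketch copies one slice per batch
// entry rather than one per channel.

```rust
/// Copy group `g` out of `data`, returning a `[batch, channels / groups, spatial]` slab.
fn sketch_group_slab(
    data: &[f32],
    batch: usize,
    channels: usize,
    spatial: usize,
    groups: usize,
    g: usize,
) -> Vec<f32> {
    assert!(groups > 0 && channels % groups == 0 && g < groups);
    assert_eq!(data.len(), batch * channels * spatial);
    let per_group = channels / groups;
    let run = per_group * spatial;
    let mut slab = Vec::with_capacity(batch * run);
    for b in 0..batch {
        // First element of channel `g * per_group` in batch entry `b`.
        let start = (b * channels + g * per_group) * spatial;
        slab.extend_from_slice(&data[start..start + run]);
    }
    slab
}
```

// With 2 batch entries, 4 channels, 1 spatial element and 2 groups, group 1
// picks channels 2 and 3 from each batch entry.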
5813
5814// ---------------------------------------------------------------------------
5815// ConvTranspose3dBackward
5816// ---------------------------------------------------------------------------
5817
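/// Backward function for the `ConvTranspose3d` forward pass.
///
/// `grad_input` is a regular (strided, dilated) 3-D convolution of
/// `grad_output` with the weight, computed per group via im2col + matmul.
/// `grad_weight` is accumulated directly by correlating the input with
/// `grad_output`, and `grad_bias` sums `grad_output` over the batch and
/// spatial dimensions.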
5821#[derive(Debug)]
5822struct ConvTranspose3dBackward<T: Float> {
5823    input: Tensor<T>,
5824    weight: Tensor<T>,
5825    bias: Option<Tensor<T>>,
5826    in_channels: usize,
5827    out_channels: usize,
5828    kernel_size: (usize, usize, usize),
5829    stride: (usize, usize, usize),
5830    padding: (usize, usize, usize),
5831    _output_padding: (usize, usize, usize),
5832    dilation: (usize, usize, usize),
5833    groups: usize,
5834    d_out: usize,
5835    h_out: usize,
5836    w_out: usize,
5837}
5838
5839impl<T: Float> GradFn<T> for ConvTranspose3dBackward<T> {
5840    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
5841        // grad_output shape: [B, C_out, D_out, H_out, W_out]
5842        let go_data = grad_output.data_vec()?;
5843        let batch = self.input.shape()[0];
5844        let d_in = self.input.shape()[2];
5845        let h_in = self.input.shape()[3];
5846        let w_in = self.input.shape()[4];
5847        let (kd, kh, kw) = self.kernel_size;
5848        let (sd, sh, sw) = self.stride;
5849        let (pd, ph, pw) = self.padding;
5850        let (dd_, dh_, dw_) = self.dilation;
5851        let groups = self.groups;
5852        let in_pg = self.in_channels / groups;
5853        let out_pg = self.out_channels / groups;
5854        let spatial_out = self.d_out * self.h_out * self.w_out;
5855        let spatial_in = d_in * h_in * w_in;
5856        let zero = <T as num_traits::Zero>::zero();
5857
5858        let weight_data_all = self.weight.data_vec()?;
5859        let input_data_all = self.input.data_vec()?;
5860
5861        let mut gi_all = if self.input.requires_grad() {
5862            Some(vec![zero; batch * self.in_channels * spatial_in])
5863        } else {
5864            None
5865        };
5866        let mut gw_all = if self.weight.requires_grad() {
5867            Some(vec![zero; self.in_channels * out_pg * kd * kh * kw])
5868        } else {
5869            None
5870        };
5871
5872        for g in 0..groups {
5873            // --- grad_input (group g) ---
5874            if let Some(gi) = gi_all.as_mut() {
5875                let col_rows = out_pg * kd * kh * kw;
5876                let w_start = g * in_pg * out_pg * kd * kh * kw;
5877                let weight_2d = Tensor::from_storage(
5878                    TensorStorage::cpu(
5879                        weight_data_all[w_start..w_start + in_pg * out_pg * kd * kh * kw].to_vec(),
5880                    ),
5881                    vec![in_pg, col_rows],
5882                    false,
5883                )?;
5884
5885                let mut go_g = vec![zero; batch * out_pg * spatial_out];
5886                for b in 0..batch {
5887                    for c in 0..out_pg {
5888                        let src_c = g * out_pg + c;
5889                        let src = (b * self.out_channels + src_c) * spatial_out;
5890                        let dst = (b * out_pg + c) * spatial_out;
5891                        go_g[dst..dst + spatial_out]
5892                            .copy_from_slice(&go_data[src..src + spatial_out]);
5893                    }
5894                }
5895
5896                let (go_cols, _gcr, go_col_cols) = im2col_3d_dilated(
5897                    &go_g, batch, out_pg, self.d_out, self.h_out, self.w_out, kd, kh, kw, sd, sh,
5898                    sw, pd, ph, pw, dd_, dh_, dw_,
5899                );
5900                debug_assert_eq!(go_col_cols, spatial_in);
5901
5902                for b in 0..batch {
5903                    let col_start = b * col_rows * go_col_cols;
5904                    let col_end = col_start + col_rows * go_col_cols;
5905                    let go_cols_b = Tensor::from_storage(
5906                        TensorStorage::cpu(go_cols[col_start..col_end].to_vec()),
5907                        vec![col_rows, go_col_cols],
5908                        false,
5909                    )?;
5910                    let gi_b = mm(&weight_2d, &go_cols_b)?;
5911                    let gi_data = gi_b.data()?;
5912                    for c in 0..in_pg {
5913                        let dst_c = g * in_pg + c;
5914                        let dst = (b * self.in_channels + dst_c) * spatial_in;
5915                        let src = c * spatial_in;
5916                        gi[dst..dst + spatial_in].copy_from_slice(&gi_data[src..src + spatial_in]);
5917                    }
5918                }
5919            }
5920
5921            // --- grad_weight (group g) ---
5922            if let Some(gw) = gw_all.as_mut() {
5923                for ci in 0..in_pg {
5924                    let in_c = g * in_pg + ci;
5925                    for co in 0..out_pg {
5926                        let out_c = g * out_pg + co;
5927                        for tkd in 0..kd {
5928                            for tkh in 0..kh {
5929                                for tkw in 0..kw {
5930                                    let mut acc = zero;
5931                                    for id in 0..d_in {
5932                                        for ih in 0..h_in {
5933                                            for iw in 0..w_in {
5934                                                let od = id * sd + tkd * dd_;
5935                                                let oh = ih * sh + tkh * dh_;
5936                                                let ow = iw * sw + tkw * dw_;
5937                                                if od >= pd
5938                                                    && oh >= ph
5939                                                    && ow >= pw
5940                                                    && (od - pd) < self.d_out
5941                                                    && (oh - ph) < self.h_out
5942                                                    && (ow - pw) < self.w_out
5943                                                {
5944                                                    let go_index = out_c * spatial_out
5945                                                        + (od - pd) * self.h_out * self.w_out
5946                                                        + (oh - ph) * self.w_out
5947                                                        + (ow - pw);
5948                                                    let in_index = in_c * spatial_in
5949                                                        + id * h_in * w_in
5950                                                        + ih * w_in
5951                                                        + iw;
5952                                                    for b in 0..batch {
5953                                                        let goi =
5954                                                            b * self.out_channels * spatial_out
5955                                                                + go_index;
5956                                                        let ini = b * self.in_channels * spatial_in
5957                                                            + in_index;
5958                                                        acc += input_data_all[ini] * go_data[goi];
5959                                                    }
5960                                                }
5961                                            }
5962                                        }
5963                                    }
5964                                    // gw layout: [in_channels, out_pg, kD, kH, kW].
5965                                    gw[((((in_c * out_pg + co) * kd + tkd) * kh + tkh) * kw)
5966                                        + tkw] += acc;
5967                                }
5968                            }
5969                        }
5970                    }
5971                }
5972            }
5973        }
5974
5975        let grad_input = match gi_all {
5976            Some(gi) => Some(Tensor::from_storage(
5977                TensorStorage::cpu(gi),
5978                self.input.shape().to_vec(),
5979                false,
5980            )?),
5981            None => None,
5982        };
5983        let grad_weight = match gw_all {
5984            Some(gw) => Some(Tensor::from_storage(
5985                TensorStorage::cpu(gw),
5986                vec![self.in_channels, out_pg, kd, kh, kw],
5987                false,
5988            )?),
5989            None => None,
5990        };
5991
5992        // --- grad_bias ---
5993        let grad_bias = match &self.bias {
5994            Some(b) if b.requires_grad() => {
                // Reuses the `zero` bound at the top of `backward`.
                let mut gb = vec![zero; self.out_channels];
5997                for batch_idx in 0..batch {
5998                    for c in 0..self.out_channels {
5999                        for s in 0..spatial_out {
6000                            gb[c] += go_data
6001                                [batch_idx * self.out_channels * spatial_out + c * spatial_out + s];
6002                        }
6003                    }
6004                }
6005                Some(Tensor::from_storage(
6006                    TensorStorage::cpu(gb),
6007                    vec![self.out_channels],
6008                    false,
6009                )?)
6010            }
6011            _ => None,
6012        };
6013
6014        let mut grads = vec![grad_input, grad_weight];
6015        if self.bias.is_some() {
6016            grads.push(grad_bias);
6017        }
6018        Ok(grads)
6019    }
6020
6021    fn inputs(&self) -> Vec<&Tensor<T>> {
6022        let mut v = vec![&self.input, &self.weight];
6023        if let Some(ref b) = self.bias {
6024            v.push(b);
6025        }
6026        v
6027    }
6028
6029    fn name(&self) -> &'static str {
6030        "ConvTranspose3dBackward"
6031    }
6032}
6033
6034// ---------------------------------------------------------------------------
6035// Tests
6036// ---------------------------------------------------------------------------
6037
6038#[cfg(test)]
6039mod tests {
6040    use super::*;
6041    use crate::module::Module;
6042
6043    // -----------------------------------------------------------------------
6044    // Bias init bounds — REQ-9 / closes #1450
6045    // -----------------------------------------------------------------------
6046
    /// Verifies that the Conv2d bias is drawn from `U(-bound, bound)`, where
    /// `bound = 1/sqrt((in_channels/groups) * kH * kW)` per
    /// `torch/nn/modules/conv.py:198-201`. Before the fix, the bias was
    /// initialized with `zeros_init`.
6050    #[test]
6051    fn test_conv2d_bias_init_bounded_uniform() {
6052        let in_c = 16usize;
6053        let out_c = 32usize;
6054        let kh = 3usize;
6055        let kw = 3usize;
6056        let groups = 1usize;
6057        let layer =
6058            Conv2d::<f32>::new_full(in_c, out_c, (kh, kw), (1, 1), (0, 0), (1, 1), groups, true)
6059                .unwrap();
6060        let bias = layer.bias.as_ref().expect("bias requested");
6061        let bias_data = bias.tensor().data_vec().unwrap();
6062        let fan_in = (in_c / groups) * kh * kw;
6063        let bound = 1.0_f32 / (fan_in as f32).sqrt();
6064        let mut nonzero = 0usize;
6065        for &b in &bias_data {
6066            assert!(
6067                b.abs() <= bound + 1e-6,
6068                "Conv2d bias element {b} exceeds bound {bound}"
6069            );
6070            if b != 0.0 {
6071                nonzero += 1;
6072            }
6073        }
6074        assert!(
6075            nonzero > out_c / 2,
6076            "expected most Conv2d bias entries to be nonzero; \
6077             would FAIL pre-fix when bias was zeros_init"
6078        );
6079    }
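
    // Sketch of the bound checked above, assuming the PyTorch-style rule
    // `bound = 1 / sqrt(fan_in)` with `fan_in = (in_channels / groups) * kH * kW`.
    // `bias_init_bound_sketch` is a hypothetical helper written only for this
    // illustration; it is not part of the ferrotorch API.
    fn bias_init_bound_sketch(in_channels: usize, groups: usize, kernel: (usize, usize)) -> f32 {
        let fan_in = (in_channels / groups) * kernel.0 * kernel.1;
        1.0 / (fan_in as f32).sqrt()
    }

    #[test]
    fn test_bias_init_bound_sketch() {
        // 16 input channels, groups = 1, 3x3 kernel: fan_in = 144, bound = 1/12.
        let bound = bias_init_bound_sketch(16, 1, (3, 3));
        assert!((bound - 1.0 / 12.0).abs() < 1e-7);
        // Grouping shrinks fan_in, so the bound grows: groups = 4 gives 1/6.
        let grouped = bias_init_bound_sketch(16, 4, (3, 3));
        assert!((grouped - 1.0 / 6.0).abs() < 1e-7);
    }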
6080
6081    /// Helper: create a tensor from flat data and shape.
6082    fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
6083        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
6084    }
6085
6086    /// Helper: create a leaf tensor that requires grad.
6087    fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
6088        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
6089    }
6090
6091    /// Assert two slices are element-wise close.
6092    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
6093        assert_eq!(
6094            actual.len(),
6095            expected.len(),
6096            "length mismatch: {} vs {}",
6097            actual.len(),
6098            expected.len()
6099        );
6100        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
6101            assert!(
6102                (a - e).abs() < tol,
6103                "index {i}: actual={a} expected={e} (diff {})",
6104                (a - e).abs()
6105            );
6106        }
6107    }
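
    // Sketch of why the `grad_input` of a transposed convolution is a regular
    // convolution of `grad_output`, reduced to 1-D with a single channel and
    // `output_padding = 0`. Both helpers are hypothetical and written only for
    // this illustration; they assume non-empty `x` and `w` and do not mirror
    // the im2col-based code in this file.
    fn conv_transpose1d_sketch(
        x: &[f32],
        w: &[f32],
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> Vec<f32> {
        let out_len = (x.len() - 1) * stride + dilation * (w.len() - 1) + 1 - 2 * padding;
        let mut y = vec![0.0f32; out_len];
        for (i, &xi) in x.iter().enumerate() {
            for (k, &wk) in w.iter().enumerate() {
                // Each input element scatters into output position i*s + k*d - p.
                let pos = i * stride + k * dilation;
                if pos >= padding && pos - padding < out_len {
                    y[pos - padding] += xi * wk;
                }
            }
        }
        y
    }

    fn conv_transpose1d_grad_input_sketch(
        go: &[f32],
        w: &[f32],
        in_len: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> Vec<f32> {
        let mut gi = vec![0.0f32; in_len];
        for (i, slot) in gi.iter_mut().enumerate() {
            for (k, &wk) in w.iter().enumerate() {
                // The gradient gathers from the same positions the forward
                // pass scattered to: a strided, dilated, padded convolution.
                let pos = i * stride + k * dilation;
                if pos >= padding && pos - padding < go.len() {
                    *slot += wk * go[pos - padding];
                }
            }
        }
        gi
    }

    #[test]
    fn test_conv_transpose1d_adjoint_sketch() {
        let x = [1.0f32, 2.0, 3.0];
        let w = [1.0f32, 2.0];
        let y = conv_transpose1d_sketch(&x, &w, 2, 0, 1);
        assert_eq!(y, vec![1.0, 2.0, 2.0, 4.0, 3.0, 6.0]);

        let go = [1.0f32, 0.0, 2.0, 0.0, 3.0, 0.0];
        let gi = conv_transpose1d_grad_input_sketch(&go, &w, x.len(), 2, 0, 1);
        assert_eq!(gi, vec![1.0, 2.0, 3.0]);

        // Adjoint identity: <conv_transpose(x), go> == <x, grad_input(go)>.
        let lhs: f32 = y.iter().zip(go.iter()).map(|(a, b)| a * b).sum();
        let rhs: f32 = x.iter().zip(gi.iter()).map(|(a, b)| a * b).sum();
        assert_eq!(lhs, rhs);
    }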
6108
6109    // -----------------------------------------------------------------------
6110    // Output shape
6111    // -----------------------------------------------------------------------
6112
6113    #[test]
6114    fn test_output_shape_no_padding() {
6115        // Input: [1, 1, 5, 5], kernel 3x3, stride 1, padding 0
6116        // H_out = (5 - 3) / 1 + 1 = 3, W_out = 3
6117        let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
6118        let input = t(&[0.0; 25], &[1, 1, 5, 5]);
6119        let output = conv.forward(&input).unwrap();
6120        assert_eq!(output.shape(), &[1, 1, 3, 3]);
6121    }
6122
6123    #[test]
6124    fn test_output_shape_with_padding() {
6125        // Input: [2, 3, 8, 8], kernel 3x3, stride 1, padding 1
6126        // H_out = (8 + 2 - 3) / 1 + 1 = 8
6127        let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (1, 1), true).unwrap();
6128        let input = t(&vec![0.0; 2 * 3 * 8 * 8], &[2, 3, 8, 8]);
6129        let output = conv.forward(&input).unwrap();
6130        assert_eq!(output.shape(), &[2, 16, 8, 8]);
6131    }
6132
6133    #[test]
6134    fn test_output_shape_with_stride() {
6135        // Input: [1, 1, 6, 6], kernel 3x3, stride 2, padding 0
6136        // H_out = (6 - 3) / 2 + 1 = 2
6137        let conv = Conv2d::<f32>::new(1, 4, (3, 3), (2, 2), (0, 0), false).unwrap();
6138        let input = t(&[0.0; 36], &[1, 1, 6, 6]);
6139        let output = conv.forward(&input).unwrap();
6140        assert_eq!(output.shape(), &[1, 4, 2, 2]);
6141    }
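
    // The shape comments in the three tests above follow the usual
    // convolution output-size rule. A minimal sketch, assuming the PyTorch
    // convention
    // `out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1`.
    // `conv_out_len_sketch` is a hypothetical helper, not a ferrotorch function.
    fn conv_out_len_sketch(
        input: usize,
        kernel: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> Option<usize> {
        // Effective kernel extent once dilation spreads the taps apart.
        let eff_kernel = dilation * kernel.checked_sub(1)? + 1;
        let padded = input + 2 * padding;
        if stride == 0 || padded < eff_kernel {
            return None;
        }
        Some((padded - eff_kernel) / stride + 1)
    }

    #[test]
    fn test_conv_out_len_sketch() {
        assert_eq!(conv_out_len_sketch(5, 3, 1, 0, 1), Some(3)); // no padding
        assert_eq!(conv_out_len_sketch(8, 3, 1, 1, 1), Some(8)); // padding 1
        assert_eq!(conv_out_len_sketch(6, 3, 2, 0, 1), Some(2)); // stride 2
        assert_eq!(conv_out_len_sketch(7, 3, 1, 0, 2), Some(3)); // dilation 2
        assert_eq!(conv_out_len_sketch(5, 0, 1, 0, 1), None); // zero kernel
        assert_eq!(conv_out_len_sketch(5, 3, 0, 0, 1), None); // zero stride
    }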
6142
6143    // -----------------------------------------------------------------------
6144    // 1x1 convolution == linear (per-pixel)
6145    // -----------------------------------------------------------------------
6146
6147    #[test]
6148    fn test_1x1_conv_equals_linear() {
6149        // A 1x1 conv with 2 input channels and 3 output channels is equivalent
6150        // to a linear layer applied independently at each spatial position.
6151        //
6152        // weight shape: [3, 2, 1, 1] -- interpreted as a [3, 2] matrix
6153        // input shape: [1, 2, 2, 2]  -- 2 channels, 2x2 spatial
6154        //
6155        // For each pixel (h, w): output[:, h, w] = weight.squeeze() @ input[:, h, w]
6156
6157        let weight_data: Vec<f32> = vec![
6158            1.0, 2.0, // out_channel 0: [1, 2]
6159            3.0, 4.0, // out_channel 1: [3, 4]
6160            5.0, 6.0, // out_channel 2: [5, 6]
6161        ];
6162        // Input: channel 0 = [[1, 2], [3, 4]], channel 1 = [[5, 6], [7, 8]]
6163        let input_data: Vec<f32> = vec![
6164            1.0, 2.0, 3.0, 4.0, // channel 0
6165            5.0, 6.0, 7.0, 8.0, // channel 1
6166        ];
6167
6168        // Manually construct Conv2d with known weights.
6169        let weight_param = Parameter::from_slice(&weight_data, &[3, 2, 1, 1]).unwrap();
6170        let conv = Conv2d {
6171            weight: weight_param,
6172            bias: None,
6173            in_channels: 2,
6174            out_channels: 3,
6175            kernel_size: (1, 1),
6176            stride: (1, 1),
6177            padding: (0, 0),
6178            dilation: (1, 1),
6179            groups: 1,
6180            padding_mode: crate::padding::PaddingMode::Zeros,
6181            string_padding: None,
6182            training: false,
6183        };
6184
6185        let input = t(&input_data, &[1, 2, 2, 2]);
6186        let output = conv.forward(&input).unwrap();
6187        assert_eq!(output.shape(), &[1, 3, 2, 2]);
6188
6189        let out = output.data().unwrap();
6190
6191        // Pixel (0,0): in = [1, 5], out = [1*1+2*5, 3*1+4*5, 5*1+6*5] = [11, 23, 35]
6192        // Pixel (0,1): in = [2, 6], out = [1*2+2*6, 3*2+4*6, 5*2+6*6] = [14, 30, 46]
6193        // Pixel (1,0): in = [3, 7], out = [1*3+2*7, 3*3+4*7, 5*3+6*7] = [17, 37, 57]
6194        // Pixel (1,1): in = [4, 8], out = [1*4+2*8, 3*4+4*8, 5*4+6*8] = [20, 44, 68]
6195
6196        // Output layout: [C_out, H, W] = [3, 2, 2]
6197        // Channel 0: [11, 14, 17, 20]
6198        // Channel 1: [23, 30, 37, 44]
6199        // Channel 2: [35, 46, 57, 68]
6200        let expected = [
6201            11.0, 14.0, 17.0, 20.0, // out channel 0
6202            23.0, 30.0, 37.0, 44.0, // out channel 1
6203            35.0, 46.0, 57.0, 68.0, // out channel 2
6204        ];
6205        assert_close(out, &expected, 1e-5);
6206    }
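
    // Reference for the 1x1 case above, as a sketch: with the weight viewed as
    // a row-major `[c_out, c_in]` matrix and the input as `[c_in, pixels]`,
    // the output is their matrix product. `pointwise_conv_sketch` is a
    // hypothetical helper for this illustration only (batch size 1, no bias).
    fn pointwise_conv_sketch(
        weight: &[f32],
        input: &[f32],
        c_out: usize,
        c_in: usize,
        pixels: usize,
    ) -> Vec<f32> {
        let mut out = vec![0.0f32; c_out * pixels];
        for o in 0..c_out {
            for i in 0..c_in {
                let w = weight[o * c_in + i];
                for p in 0..pixels {
                    out[o * pixels + p] += w * input[i * pixels + p];
                }
            }
        }
        out
    }

    #[test]
    fn test_pointwise_conv_sketch() {
        let weight = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let out = pointwise_conv_sketch(&weight, &input, 3, 2, 4);
        // Same values as `expected` in `test_1x1_conv_equals_linear`.
        assert_eq!(
            out,
            vec![11.0, 14.0, 17.0, 20.0, 23.0, 30.0, 37.0, 44.0, 35.0, 46.0, 57.0, 68.0]
        );
    }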
6207
6208    // -----------------------------------------------------------------------
6209    // Bias
6210    // -----------------------------------------------------------------------
6211
6212    #[test]
6213    fn test_bias_addition() {
6214        // 1x1 conv with bias.
6215        let weight_data = vec![1.0f32]; // [1, 1, 1, 1]
6216        let bias_data = vec![10.0f32]; // [1]
6217
6218        let conv = Conv2d {
6219            weight: Parameter::from_slice(&weight_data, &[1, 1, 1, 1]).unwrap(),
6220            bias: Some(Parameter::from_slice(&bias_data, &[1]).unwrap()),
6221            in_channels: 1,
6222            out_channels: 1,
6223            kernel_size: (1, 1),
6224            stride: (1, 1),
6225            padding: (0, 0),
6226            dilation: (1, 1),
6227            groups: 1,
6228            padding_mode: crate::padding::PaddingMode::Zeros,
6229            string_padding: None,
6230            training: false,
6231        };
6232
6233        let input = t(&[2.0, 3.0, 4.0, 5.0], &[1, 1, 2, 2]);
6234        let output = conv.forward(&input).unwrap();
6235        // output = input * 1.0 + 10.0
6236        assert_close(output.data().unwrap(), &[12.0, 13.0, 14.0, 15.0], 1e-5);
6237    }
6238
6239    // -----------------------------------------------------------------------
6240    // Backward shape
6241    // -----------------------------------------------------------------------
6242
6243    #[test]
6244    fn test_backward_produces_correct_shapes() {
        // Invoke the backward function directly and check the gradient shapes.
6246        let weight_data = vec![1.0f32; 2 * 3 * 3]; // [2, 1, 3, 3]
6247        let input_data = vec![1.0f32; 5 * 5]; // [1, 1, 5, 5]
6248        let bias_data = vec![0.0f32; 2];
6249
6250        let weight_param = Parameter::from_slice(&weight_data, &[2, 1, 3, 3]).unwrap();
6251        let bias_param = Parameter::from_slice(&bias_data, &[2]).unwrap();
6252
6253        let conv = Conv2d {
6254            weight: weight_param,
6255            bias: Some(bias_param),
6256            in_channels: 1,
6257            out_channels: 2,
6258            kernel_size: (3, 3),
6259            stride: (1, 1),
6260            padding: (0, 0),
6261            dilation: (1, 1),
6262            groups: 1,
6263            padding_mode: crate::padding::PaddingMode::Zeros,
6264            string_padding: None,
6265            training: false,
6266        };
6267
6268        // Forward to get the grad_fn.
6269        let input = leaf(&input_data, &[1, 1, 5, 5]);
6270        let output = conv.forward(&input).unwrap();
6271        assert_eq!(output.shape(), &[1, 2, 3, 3]);
6272
6273        // Make sure grad_fn is attached.
6274        assert!(output.grad_fn().is_some());
6275        assert_eq!(output.grad_fn().unwrap().name(), "Conv2dBackward");
6276
6277        // Construct a grad_output of the right shape.
6278        let grad_output = t(&[1.0; 2 * 3 * 3], &[1, 2, 3, 3]);
6279        let grads = output.grad_fn().unwrap().backward(&grad_output).unwrap();
6280
6281        // grad_input shape should be [1, 1, 5, 5]
6282        assert!(grads[0].is_some());
6283        assert_eq!(grads[0].as_ref().unwrap().shape(), &[1, 1, 5, 5]);
6284
6285        // grad_weight shape should be [2, 1, 3, 3]
6286        assert!(grads[1].is_some());
6287        assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 1, 3, 3]);
6288
6289        // grad_bias shape should be [2]
6290        assert!(grads[2].is_some());
6291        assert_eq!(grads[2].as_ref().unwrap().shape(), &[2]);
6292    }
6293
6294    // -----------------------------------------------------------------------
6295    // Parameter count
6296    // -----------------------------------------------------------------------
6297
6298    #[test]
6299    fn test_parameter_count_with_bias() {
6300        let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (0, 0), true).unwrap();
6301        // weight: 16 * 3 * 3 * 3 = 432
6302        // bias:   16
6303        // total:  448
6304        assert_eq!(conv.num_parameters(), 448);
6305        assert_eq!(conv.parameters().len(), 2);
6306    }
6307
6308    #[test]
6309    fn test_parameter_count_without_bias() {
6310        let conv = Conv2d::<f32>::new(3, 16, (3, 3), (1, 1), (0, 0), false).unwrap();
6311        assert_eq!(conv.num_parameters(), 432);
6312        assert_eq!(conv.parameters().len(), 1);
6313    }
6314
6315    // -----------------------------------------------------------------------
6316    // Module trait
6317    // -----------------------------------------------------------------------
6318
6319    #[test]
6320    fn test_named_parameters() {
6321        let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), true).unwrap();
6322        let named = conv.named_parameters();
6323        assert_eq!(named.len(), 2);
6324        assert_eq!(named[0].0, "weight");
6325        assert_eq!(named[1].0, "bias");
6326    }
6327
6328    #[test]
6329    fn test_train_eval() {
6330        let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
6331        assert!(conv.is_training());
6332        conv.eval();
6333        assert!(!conv.is_training());
6334        conv.train();
6335        assert!(conv.is_training());
6336    }
6337
6338    // -----------------------------------------------------------------------
6339    // Edge cases
6340    // -----------------------------------------------------------------------
6341
6342    #[test]
6343    fn test_invalid_input_ndim() {
6344        let conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
6345        let input = t(&[1.0, 2.0, 3.0], &[3]);
6346        assert!(conv.forward(&input).is_err());
6347    }
6348
6349    #[test]
6350    fn test_channel_mismatch() {
6351        let conv = Conv2d::<f32>::new(3, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
6352        let input = t(&[0.0; 5 * 5], &[1, 1, 5, 5]);
6353        assert!(conv.forward(&input).is_err());
6354    }
6355
6356    #[test]
6357    fn test_zero_channels_rejected() {
6358        assert!(Conv2d::<f32>::new(0, 16, (3, 3), (1, 1), (0, 0), false).is_err());
6359        assert!(Conv2d::<f32>::new(3, 0, (3, 3), (1, 1), (0, 0), false).is_err());
6360    }
6361
6362    #[test]
6363    fn test_zero_kernel_rejected() {
6364        assert!(Conv2d::<f32>::new(1, 1, (0, 3), (1, 1), (0, 0), false).is_err());
6365    }
6366
6367    #[test]
6368    fn test_zero_stride_rejected() {
6369        assert!(Conv2d::<f32>::new(1, 1, (3, 3), (0, 1), (0, 0), false).is_err());
6370    }
6371
6372    // -----------------------------------------------------------------------
6373    // im2col / col2im roundtrip
6374    // -----------------------------------------------------------------------
6375
6376    #[test]
6377    fn test_im2col_basic() {
6378        // 1 batch, 1 channel, 3x3 input, 2x2 kernel, stride 1, no padding
6379        // H_out = 2, W_out = 2
6380        // Columns: each column is a flattened 2x2 patch
6381        #[rustfmt::skip]
6382        let input: Vec<f32> = vec![
6383            1.0, 2.0, 3.0,
6384            4.0, 5.0, 6.0,
6385            7.0, 8.0, 9.0,
6386        ];
6387        let (cols, rows, n_cols) = im2col(&input, 1, 1, 3, 3, 2, 2, 1, 1, 0, 0);
6388        assert_eq!(rows, 4); // 1 * 2 * 2
6389        assert_eq!(n_cols, 4); // 2 * 2
6390
6391        // Patch (0,0): [1, 2, 4, 5]
6392        // Patch (0,1): [2, 3, 5, 6]
6393        // Patch (1,0): [4, 5, 7, 8]
6394        // Patch (1,1): [5, 6, 8, 9]
6395        //
6396        // cols layout: [row][col] where row = c*kH*kW+kh*kW+kw, col = oh*W_out+ow
6397        // Row 0 (kh=0,kw=0): [1, 2, 4, 5]
6398        // Row 1 (kh=0,kw=1): [2, 3, 5, 6]
6399        // Row 2 (kh=1,kw=0): [4, 5, 7, 8]
6400        // Row 3 (kh=1,kw=1): [5, 6, 8, 9]
6401        assert_close(
6402            &cols,
6403            &[
6404                1.0, 2.0, 4.0, 5.0, // row 0
6405                2.0, 3.0, 5.0, 6.0, // row 1
6406                4.0, 5.0, 7.0, 8.0, // row 2
6407                5.0, 6.0, 8.0, 9.0, // row 3
6408            ],
6409            1e-7,
6410        );
6411    }
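
    // Sketch of the layout described in the comments above: row index
    // `c*kH*kW + kh*kW + kw`, column index `oh*W_out + ow`. `im2col_sketch` is
    // a hypothetical, simplified stand-in (batch size 1, no padding, no
    // dilation, kernel no larger than the input); it is not the `im2col` used
    // by the layers in this file.
    fn im2col_sketch(
        input: &[f32],
        channels: usize,
        h: usize,
        w: usize,
        kh: usize,
        kw: usize,
        stride: usize,
    ) -> (Vec<f32>, usize, usize) {
        let h_out = (h - kh) / stride + 1;
        let w_out = (w - kw) / stride + 1;
        let rows = channels * kh * kw;
        let n_cols = h_out * w_out;
        let mut cols = vec![0.0f32; rows * n_cols];
        for c in 0..channels {
            for i in 0..kh {
                for j in 0..kw {
                    let row = (c * kh + i) * kw + j;
                    for oh in 0..h_out {
                        for ow in 0..w_out {
                            let src = (c * h + oh * stride + i) * w + ow * stride + j;
                            cols[row * n_cols + oh * w_out + ow] = input[src];
                        }
                    }
                }
            }
        }
        (cols, rows, n_cols)
    }

    #[test]
    fn test_im2col_sketch_layout() {
        let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
        let (cols, rows, n_cols) = im2col_sketch(&input, 1, 3, 3, 2, 2, 1);
        assert_eq!((rows, n_cols), (4, 4));
        assert_eq!(
            cols,
            vec![1.0, 2.0, 4.0, 5.0, 2.0, 3.0, 5.0, 6.0, 4.0, 5.0, 7.0, 8.0, 5.0, 6.0, 8.0, 9.0]
        );
        // The convolution is then a matmul: a flattened kernel [1, 0, 0, -1]
        // times `cols` yields one output value per column.
        let kernel = [1.0f32, 0.0, 0.0, -1.0];
        let out: Vec<f32> = (0..n_cols)
            .map(|col| (0..rows).map(|r| kernel[r] * cols[r * n_cols + col]).sum())
            .collect();
        assert_eq!(out, vec![-4.0, -4.0, -4.0, -4.0]);
    }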
6412
6413    #[test]
6414    fn test_col2im_roundtrip_no_overlap() {
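        // With stride == kernel_size, no padding, and an input size divisible
        // by the kernel, the patches tile the input without overlap, so
        // im2col followed by col2im recovers the input exactly.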
6416        // 1 batch, 1 channel, 4x4, kernel 2x2, stride 2, no padding
6417        // H_out = 2, W_out = 2
6418        #[rustfmt::skip]
6419        let input: Vec<f32> = vec![
6420            1.0, 2.0, 3.0, 4.0,
6421            5.0, 6.0, 7.0, 8.0,
6422            9.0, 10.0, 11.0, 12.0,
6423            13.0, 14.0, 15.0, 16.0,
6424        ];
6425
6426        let (cols, _rows, _n_cols) = im2col(&input, 1, 1, 4, 4, 2, 2, 2, 2, 0, 0);
6427        let recovered = col2im(&cols, 1, 1, 4, 4, 2, 2, 2, 2, 0, 0, 2, 2);
6428
6429        assert_close(&recovered, &input, 1e-7);
6430    }
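
    // Sketch of why the roundtrip above needs non-overlapping patches, reduced
    // to 1-D: folding columns back sums every patch that covers a position, so
    // overlapped positions come back scaled by their coverage count.
    // `unfold_1d_sketch` and `fold_1d_sketch` are hypothetical helpers (single
    // channel, no padding) and not the `im2col` / `col2im` from this file.
    fn unfold_1d_sketch(input: &[f32], kernel: usize, stride: usize) -> Vec<f32> {
        let l_out = (input.len() - kernel) / stride + 1;
        let mut cols = vec![0.0f32; kernel * l_out];
        for k in 0..kernel {
            for o in 0..l_out {
                cols[k * l_out + o] = input[o * stride + k];
            }
        }
        cols
    }

    fn fold_1d_sketch(cols: &[f32], len: usize, kernel: usize, stride: usize) -> Vec<f32> {
        let l_out = (len - kernel) / stride + 1;
        let mut out = vec![0.0f32; len];
        for k in 0..kernel {
            for o in 0..l_out {
                // Accumulate rather than assign: overlapping patches add up.
                out[o * stride + k] += cols[k * l_out + o];
            }
        }
        out
    }

    #[test]
    fn test_fold_unfold_1d_sketch() {
        let input = [1.0f32, 2.0, 3.0, 4.0];
        // stride == kernel: patches tile the input, roundtrip is exact.
        let tiled = fold_1d_sketch(&unfold_1d_sketch(&input, 2, 2), 4, 2, 2);
        assert_eq!(tiled, vec![1.0, 2.0, 3.0, 4.0]);
        // stride < kernel: the two interior positions are covered twice.
        let overlapped = fold_1d_sketch(&unfold_1d_sketch(&input, 2, 1), 4, 2, 1);
        assert_eq!(overlapped, vec![1.0, 4.0, 6.0, 4.0]);
    }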
6431
6432    // -----------------------------------------------------------------------
6433    // Forward correctness with a simple 3x3 kernel
6434    // -----------------------------------------------------------------------
6435
6436    #[test]
6437    fn test_3x3_conv_forward() {
6438        // 1 batch, 1 channel, 3x3 input, 3x3 kernel, stride 1, no padding
6439        // Output: 1x1 (single value = sum of element-wise product)
6440        #[rustfmt::skip]
6441        let input_data: Vec<f32> = vec![
6442            1.0, 2.0, 3.0,
6443            4.0, 5.0, 6.0,
6444            7.0, 8.0, 9.0,
6445        ];
6446        #[rustfmt::skip]
6447        let weight_data: Vec<f32> = vec![
6448            1.0, 0.0, -1.0,
6449            1.0, 0.0, -1.0,
6450            1.0, 0.0, -1.0,
6451        ];
6452
6453        let conv = Conv2d {
6454            weight: Parameter::from_slice(&weight_data, &[1, 1, 3, 3]).unwrap(),
6455            bias: None,
6456            in_channels: 1,
6457            out_channels: 1,
6458            kernel_size: (3, 3),
6459            stride: (1, 1),
6460            padding: (0, 0),
6461            dilation: (1, 1),
6462            groups: 1,
6463            padding_mode: crate::padding::PaddingMode::Zeros,
6464            string_padding: None,
6465            training: false,
6466        };
6467
6468        let input = t(&input_data, &[1, 1, 3, 3]);
6469        let output = conv.forward(&input).unwrap();
6470        assert_eq!(output.shape(), &[1, 1, 1, 1]);
6471
6472        // Expected: 1*1 + 0*2 + (-1)*3 + 1*4 + 0*5 + (-1)*6 + 1*7 + 0*8 + (-1)*9
6473        //         = 1 - 3 + 4 - 6 + 7 - 9 = -6
6474        assert_close(output.data().unwrap(), &[-6.0], 1e-5);
6475    }
6476
6477    // -----------------------------------------------------------------------
6478    // Padding correctness
6479    // -----------------------------------------------------------------------
6480
6481    #[test]
6482    fn test_padding_preserves_spatial_size() {
6483        // Input: [1, 1, 3, 3], kernel 3x3, stride 1, padding 1
6484        // H_out = (3 + 2 - 3) / 1 + 1 = 3 (same size!)
        let mut weight_data_center = vec![0.0f32; 9];
        weight_data_center[4] = 1.0; // Center tap only: the kernel acts as an identity.
6488
6489        let conv = Conv2d {
6490            weight: Parameter::from_slice(&weight_data_center, &[1, 1, 3, 3]).unwrap(),
6491            bias: None,
6492            in_channels: 1,
6493            out_channels: 1,
6494            kernel_size: (3, 3),
6495            stride: (1, 1),
6496            padding: (1, 1),
6497            dilation: (1, 1),
6498            groups: 1,
6499            padding_mode: crate::padding::PaddingMode::Zeros,
6500            string_padding: None,
6501            training: false,
6502        };
6503
6504        let input_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
6505        let input = t(&input_data, &[1, 1, 3, 3]);
6506        let output = conv.forward(&input).unwrap();
6507        assert_eq!(output.shape(), &[1, 1, 3, 3]);
6508
6509        // With center-only kernel + padding, output should equal input.
6510        assert_close(output.data().unwrap(), &input_data, 1e-5);
6511    }
6512
6513    // ===================================================================
6514    // Conv1d tests
6515    // ===================================================================
6516
6517    // -----------------------------------------------------------------------
6518    // Conv1d: output shape
6519    // -----------------------------------------------------------------------
6520
6521    #[test]
6522    fn test_conv1d_output_shape_no_padding() {
6523        // Input: [1, 1, 10], kernel 3, stride 1, padding 0
6524        // L_out = (10 - 3) / 1 + 1 = 8
6525        let conv = Conv1d::<f32>::new(1, 4, 3, 1, 0, false).unwrap();
6526        let input = t(&[0.0; 10], &[1, 1, 10]);
6527        let output = conv.forward(&input).unwrap();
6528        assert_eq!(output.shape(), &[1, 4, 8]);
6529    }
6530
6531    #[test]
6532    fn test_conv1d_output_shape_with_padding() {
6533        // Input: [2, 3, 16], kernel 3, stride 1, padding 1
6534        // L_out = (16 + 2 - 3) / 1 + 1 = 16
6535        let conv = Conv1d::<f32>::new(3, 8, 3, 1, 1, true).unwrap();
6536        let input = t(&vec![0.0; 2 * 3 * 16], &[2, 3, 16]);
6537        let output = conv.forward(&input).unwrap();
6538        assert_eq!(output.shape(), &[2, 8, 16]);
6539    }
6540
6541    #[test]
6542    fn test_conv1d_output_shape_with_stride() {
6543        // Input: [1, 1, 10], kernel 3, stride 2, padding 0
6544        // L_out = (10 - 3) / 2 + 1 = 4
6545        let conv = Conv1d::<f32>::new(1, 2, 3, 2, 0, false).unwrap();
6546        let input = t(&[0.0; 10], &[1, 1, 10]);
6547        let output = conv.forward(&input).unwrap();
6548        assert_eq!(output.shape(), &[1, 2, 4]);
6549    }
6550
6551    // -----------------------------------------------------------------------
    // Conv1d: kernel_size = 1 correctness (acts as a per-position linear layer)
6553    // -----------------------------------------------------------------------
6554
6555    #[test]
6556    fn test_conv1d_1x1_kernel_correctness() {
6557        // A kernel_size=1 Conv1d is equivalent to a linear layer applied at
6558        // each position independently.
6559        //
6560        // weight: [2, 1, 1] = [[3.0], [5.0]]
6561        // input:  [1, 1, 4] = [1, 2, 3, 4]
6562        // output: [1, 2, 4]
6563        //   out_ch 0: [3, 6, 9, 12]
6564        //   out_ch 1: [5, 10, 15, 20]
6565        let weight_data = vec![3.0f32, 5.0];
6566        let conv = Conv1d {
6567            weight: Parameter::from_slice(&weight_data, &[2, 1, 1]).unwrap(),
6568            bias: None,
6569            in_channels: 1,
6570            out_channels: 2,
6571            kernel_size: 1,
6572            stride: 1,
6573            padding: 0,
6574            dilation: 1,
6575            groups: 1,
6576            padding_mode: crate::padding::PaddingMode::Zeros,
6577            string_padding: None,
6578            training: false,
6579        };
6580
6581        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4]);
6582        let output = conv.forward(&input).unwrap();
6583        assert_eq!(output.shape(), &[1, 2, 4]);
6584        assert_close(
6585            output.data().unwrap(),
6586            &[3.0, 6.0, 9.0, 12.0, 5.0, 10.0, 15.0, 20.0],
6587            1e-5,
6588        );
6589    }
6590
6591    // -----------------------------------------------------------------------
6592    // Conv1d: forward correctness with a 3-wide kernel
6593    // -----------------------------------------------------------------------
6594
6595    #[test]
6596    fn test_conv1d_3_kernel_forward() {
6597        // Input: [1, 1, 5] = [1, 2, 3, 4, 5]
6598        // Kernel: [1, 1, 3] = [1, 0, -1]
6599        // Stride 1, padding 0 => L_out = 3
6600        // Expected: [1*1+0*2+(-1)*3, 1*2+0*3+(-1)*4, 1*3+0*4+(-1)*5] = [-2, -2, -2]
6601        let conv = Conv1d {
6602            weight: Parameter::from_slice(&[1.0f32, 0.0, -1.0], &[1, 1, 3]).unwrap(),
6603            bias: None,
6604            in_channels: 1,
6605            out_channels: 1,
6606            kernel_size: 3,
6607            stride: 1,
6608            padding: 0,
6609            dilation: 1,
6610            groups: 1,
6611            padding_mode: crate::padding::PaddingMode::Zeros,
6612            string_padding: None,
6613            training: false,
6614        };
6615
6616        let input = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
6617        let output = conv.forward(&input).unwrap();
6618        assert_eq!(output.shape(), &[1, 1, 3]);
6619        assert_close(output.data().unwrap(), &[-2.0, -2.0, -2.0], 1e-5);
6620    }
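
    // Illustrative sketch: the expected values in the test above come from
    // cross-correlation (the kernel is NOT flipped), which is what deep
    // learning frameworks conventionally call "convolution".
    // `cross_correlate_1d` is a hypothetical single-channel, no-padding
    // reference, not a function of this crate. It assumes `stride >= 1` and
    // `1 <= kernel.len() <= input.len()`.
    fn cross_correlate_1d(input: &[f32], kernel: &[f32], stride: usize) -> Vec<f32> {
        let n_out = (input.len() - kernel.len()) / stride + 1;
        (0..n_out)
            .map(|i| {
                // Dot product of the kernel with the window starting at i * stride.
                kernel
                    .iter()
                    .enumerate()
                    .map(|(k, w)| w * input[i * stride + k])
                    .sum::<f32>()
            })
            .collect()
    }

    #[test]
    fn test_cross_correlate_1d_sketch() {
        // Same input and kernel as the Conv1d test above.
        let out = cross_correlate_1d(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1.0, 0.0, -1.0], 1);
        assert_eq!(out, vec![-2.0, -2.0, -2.0]);
    }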
6621
6622    // -----------------------------------------------------------------------
6623    // Conv1d: bias
6624    // -----------------------------------------------------------------------
6625
6626    #[test]
6627    fn test_conv1d_bias() {
6628        let conv = Conv1d {
6629            weight: Parameter::from_slice(&[1.0f32], &[1, 1, 1]).unwrap(),
6630            bias: Some(Parameter::from_slice(&[10.0f32], &[1]).unwrap()),
6631            in_channels: 1,
6632            out_channels: 1,
6633            kernel_size: 1,
6634            stride: 1,
6635            padding: 0,
6636            dilation: 1,
6637            groups: 1,
6638            padding_mode: crate::padding::PaddingMode::Zeros,
6639            string_padding: None,
6640            training: false,
6641        };
6642
6643        let input = t(&[2.0, 3.0, 4.0], &[1, 1, 3]);
6644        let output = conv.forward(&input).unwrap();
6645        assert_close(output.data().unwrap(), &[12.0, 13.0, 14.0], 1e-5);
6646    }
6647
6648    // -----------------------------------------------------------------------
6649    // Conv1d: edge cases
6650    // -----------------------------------------------------------------------
6651
6652    #[test]
6653    fn test_conv1d_invalid_ndim() {
6654        let conv = Conv1d::<f32>::new(1, 1, 3, 1, 0, false).unwrap();
6655        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
6656        assert!(conv.forward(&input).is_err());
6657    }
6658
6659    #[test]
6660    fn test_conv1d_channel_mismatch() {
6661        let conv = Conv1d::<f32>::new(3, 1, 3, 1, 0, false).unwrap();
6662        let input = t(&[0.0; 10], &[1, 1, 10]);
6663        assert!(conv.forward(&input).is_err());
6664    }
6665
6666    #[test]
6667    fn test_conv1d_zero_channels_rejected() {
6668        assert!(Conv1d::<f32>::new(0, 4, 3, 1, 0, false).is_err());
6669        assert!(Conv1d::<f32>::new(1, 0, 3, 1, 0, false).is_err());
6670    }
6671
6672    #[test]
6673    fn test_conv1d_zero_kernel_rejected() {
6674        assert!(Conv1d::<f32>::new(1, 1, 0, 1, 0, false).is_err());
6675    }
6676
6677    #[test]
6678    fn test_conv1d_zero_stride_rejected() {
6679        assert!(Conv1d::<f32>::new(1, 1, 3, 0, 0, false).is_err());
6680    }
6681
6682    #[test]
6683    fn test_conv1d_parameter_count() {
6684        let conv = Conv1d::<f32>::new(3, 8, 5, 1, 0, true).unwrap();
6685        // weight: 8 * 3 * 5 = 120, bias: 8, total: 128
6686        assert_eq!(conv.num_parameters(), 128);
6687        assert_eq!(conv.parameters().len(), 2);
6688    }
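
    // Illustrative sketch of the arithmetic in the parameter-count comment
    // above, assuming a weight of shape
    // [out_channels, in_channels / groups, *kernel] plus an optional bias of
    // length out_channels. `conv_param_count` is a hypothetical helper, not
    // part of the crate; it assumes `groups >= 1` divides `in_ch`.
    fn conv_param_count(
        in_ch: usize,
        out_ch: usize,
        kernel_dims: &[usize],
        groups: usize,
        bias: bool,
    ) -> usize {
        // Number of taps in one kernel, e.g. 5 for 1-D or 3 * 3 * 3 for 3-D.
        let taps: usize = kernel_dims.iter().product();
        let weight = out_ch * (in_ch / groups) * taps;
        weight + if bias { out_ch } else { 0 }
    }

    #[test]
    fn test_conv_param_count_sketch() {
        // Matches the Conv1d case above: 8 * 3 * 5 + 8 = 128.
        assert_eq!(conv_param_count(3, 8, &[5], 1, true), 128);
        assert_eq!(conv_param_count(3, 8, &[5], 1, false), 120);
    }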
6689
6690    // ===================================================================
6691    // ConvTranspose2d tests
6692    // ===================================================================
6693
6694    // -----------------------------------------------------------------------
6695    // ConvTranspose2d: output shape
6696    // -----------------------------------------------------------------------
6697
6698    #[test]
6699    fn test_conv_transpose2d_output_shape_basic() {
6700        // Input: [1, 1, 3, 3], kernel 3x3, stride 1, padding 0, output_padding 0
6701        // H_out = (3 - 1) * 1 - 0 + 3 + 0 = 5
6702        let conv =
6703            ConvTranspose2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
6704        let input = t(&[0.0; 9], &[1, 1, 3, 3]);
6705        let output = conv.forward(&input).unwrap();
6706        assert_eq!(output.shape(), &[1, 1, 5, 5]);
6707    }
6708
6709    #[test]
6710    fn test_conv_transpose2d_output_shape_stride2() {
6711        // Input: [1, 1, 2, 2], kernel 3x3, stride 2, padding 0, output_padding 0
6712        // H_out = (2 - 1) * 2 - 0 + 3 + 0 = 5
6713        let conv =
6714            ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (0, 0), (0, 0), false).unwrap();
6715        let input = t(&[0.0; 4], &[1, 1, 2, 2]);
6716        let output = conv.forward(&input).unwrap();
6717        assert_eq!(output.shape(), &[1, 1, 5, 5]);
6718    }
6719
6720    #[test]
6721    fn test_conv_transpose2d_output_shape_with_padding() {
6722        // Input: [1, 1, 3, 3], kernel 3x3, stride 2, padding 1, output_padding 0
6723        // H_out = (3 - 1) * 2 - 2 * 1 + 3 + 0 = 5
6724        let conv =
6725            ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (1, 1), (0, 0), false).unwrap();
6726        let input = t(&[0.0; 9], &[1, 1, 3, 3]);
6727        let output = conv.forward(&input).unwrap();
6728        assert_eq!(output.shape(), &[1, 1, 5, 5]);
6729    }
6730
6731    #[test]
6732    fn test_conv_transpose2d_output_shape_with_output_padding() {
6733        // Input: [1, 1, 3, 3], kernel 3x3, stride 2, padding 1, output_padding 1
6734        // H_out = (3 - 1) * 2 - 2 * 1 + 3 + 1 = 6
6735        let conv =
6736            ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (1, 1), (1, 1), false).unwrap();
6737        let input = t(&[0.0; 9], &[1, 1, 3, 3]);
6738        let output = conv.forward(&input).unwrap();
6739        assert_eq!(output.shape(), &[1, 1, 6, 6]);
6740    }
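
    // Illustrative sketch of the output-size arithmetic used in the comments
    // above, assuming the usual transposed-convolution formula:
    //   out = (in - 1) * stride - 2 * padding
    //         + dilation * (kernel - 1) + output_padding + 1
    // With dilation 1 this reduces to the
    // `(in - 1) * stride - 2 * padding + kernel + output_padding` form in the
    // test comments. `conv_transpose_out_len` is a hypothetical helper, not
    // part of the crate; it assumes `l_in >= 1` and `kernel >= 1`.
    fn conv_transpose_out_len(
        l_in: usize,
        kernel: usize,
        stride: usize,
        padding: usize,
        output_padding: usize,
        dilation: usize,
    ) -> usize {
        // Add every positive term first so the `usize` subtraction of
        // 2 * padding cannot underflow for valid configurations.
        (l_in - 1) * stride + dilation * (kernel - 1) + output_padding + 1 - 2 * padding
    }

    #[test]
    fn test_conv_transpose_out_len_sketch() {
        // Same per-axis configurations as the four shape tests above.
        assert_eq!(conv_transpose_out_len(3, 3, 1, 0, 0, 1), 5);
        assert_eq!(conv_transpose_out_len(2, 3, 2, 0, 0, 1), 5);
        assert_eq!(conv_transpose_out_len(3, 3, 2, 1, 0, 1), 5);
        assert_eq!(conv_transpose_out_len(3, 3, 2, 1, 1, 1), 6);
    }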
6741
6742    // -----------------------------------------------------------------------
6743    // ConvTranspose2d: stride=2 doubles spatial dims (upsampling)
6744    // -----------------------------------------------------------------------
6745
6746    #[test]
6747    fn test_conv_transpose2d_stride2_upsamples() {
6748        // With stride=2, kernel=2x2, padding=0, output_padding=0:
6749        // H_out = (H - 1) * 2 + 2 = 2 * H
6750        // So a 4x4 input becomes 8x8, doubling the spatial dims.
6751        let conv =
6752            ConvTranspose2d::<f32>::new(1, 1, (2, 2), (2, 2), (0, 0), (0, 0), false).unwrap();
6753        let input = t(&[0.0; 4 * 4], &[1, 1, 4, 4]);
6754        let output = conv.forward(&input).unwrap();
6755        assert_eq!(output.shape(), &[1, 1, 8, 8]);
6756    }
6757
6758    #[test]
6759    fn test_conv_transpose2d_stride2_upsamples_multichannel() {
6760        // [2, 8, 4, 4] -> [2, 16, 8, 8] with stride=2, kernel=2x2
6761        let conv =
6762            ConvTranspose2d::<f32>::new(8, 16, (2, 2), (2, 2), (0, 0), (0, 0), true).unwrap();
6763        let input = t(&vec![0.0; 2 * 8 * 4 * 4], &[2, 8, 4, 4]);
6764        let output = conv.forward(&input).unwrap();
6765        assert_eq!(output.shape(), &[2, 16, 8, 8]);
6766    }
6767
6768    // -----------------------------------------------------------------------
6769    // ConvTranspose2d: 1x1 kernel correctness
6770    // -----------------------------------------------------------------------
6771
6772    #[test]
6773    fn test_conv_transpose2d_1x1_kernel() {
6774        // With a 1x1 kernel, stride 1, and no padding, the transposed conv is
6775        // equivalent to a regular 1x1 conv (a per-pixel linear transform), but
6776        // the weight stores its channel axes in swapped order:
6777        // weight shape: [in_channels=1, out_channels=2, 1, 1]
6778        // input: [1, 1, 2, 2]
6779        // Each output channel c gets: input * weight[0, c, 0, 0]
6780        let weight_data = vec![3.0f32, 7.0]; // [1, 2, 1, 1]
6781        let conv = ConvTranspose2d {
6782            weight: Parameter::from_slice(&weight_data, &[1, 2, 1, 1]).unwrap(),
6783            bias: None,
6784            in_channels: 1,
6785            out_channels: 2,
6786            kernel_size: (1, 1),
6787            stride: (1, 1),
6788            padding: (0, 0),
6789            output_padding: (0, 0),
6790            dilation: (1, 1),
6791            groups: 1,
6792            training: false,
6793        };
6794
6795        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
6796        let output = conv.forward(&input).unwrap();
6797        assert_eq!(output.shape(), &[1, 2, 2, 2]);
6798
6799        // out_ch 0: input * 3 = [3, 6, 9, 12]
6800        // out_ch 1: input * 7 = [7, 14, 21, 28]
6801        assert_close(
6802            output.data().unwrap(),
6803            &[3.0, 6.0, 9.0, 12.0, 7.0, 14.0, 21.0, 28.0],
6804            1e-5,
6805        );
6806    }
6807
6808    // -----------------------------------------------------------------------
6809    // ConvTranspose2d: correctness with stride insertion
6810    // -----------------------------------------------------------------------
6811
6812    #[test]
6813    fn test_conv_transpose2d_stride2_correctness() {
6814        // Input: [1, 1, 2, 2] = [[1, 2], [3, 4]]
6815        // Kernel: [1, 1, 2, 2] = [[1, 1], [1, 1]]  (all ones)
6816        // Stride=2, padding=0, output_padding=0
6817        // H_out = (2-1)*2 + 2 = 4, W_out = 4
6818        //
6819        // Stride insertion produces 3x3:
6820        //   [[1, 0, 2],
6821        //    [0, 0, 0],
6822        //    [3, 0, 4]]
6823        //
6824        // Flipping the all-ones kernel leaves it unchanged: [[1,1],[1,1]]
6825        // Internal conv with pad = kernel-1 = 1, stride=1 on 3x3:
6826        // Padded to 5x5:
6827        //   [[0, 0, 0, 0, 0],
6828        //    [0, 1, 0, 2, 0],
6829        //    [0, 0, 0, 0, 0],
6830        //    [0, 3, 0, 4, 0],
6831        //    [0, 0, 0, 0, 0]]
6832        // Convolve with the 2x2 all-ones kernel at stride 1 to get the 4x4 output.
6833        // Each entry is the sum of one 2x2 window of the padded input, written as
6834        // top-left + top-right + bottom-left + bottom-right:
6840        // pos(0,0): 0+0+0+1 = 1
6841        // pos(0,1): 0+0+1+0 = 1
6842        // pos(0,2): 0+0+0+2 = 2
6843        // pos(0,3): 0+0+2+0 = 2
6844        // pos(1,0): 0+1+0+0 = 1
6845        // pos(1,1): 1+0+0+0 = 1
6846        // pos(1,2): 0+2+0+0 = 2
6847        // pos(1,3): 2+0+0+0 = 2
6848        // pos(2,0): 0+0+0+3 = 3
6849        // pos(2,1): 0+0+3+0 = 3
6850        // pos(2,2): 0+0+0+4 = 4
6851        // pos(2,3): 0+0+4+0 = 4
6852        // pos(3,0): 0+3+0+0 = 3
6853        // pos(3,1): 3+0+0+0 = 3
6854        // pos(3,2): 0+4+0+0 = 4
6855        // pos(3,3): 4+0+0+0 = 4
6856
6857        let weight_data = vec![1.0f32; 4]; // [1, 1, 2, 2]
6858        let conv = ConvTranspose2d {
6859            weight: Parameter::from_slice(&weight_data, &[1, 1, 2, 2]).unwrap(),
6860            bias: None,
6861            in_channels: 1,
6862            out_channels: 1,
6863            kernel_size: (2, 2),
6864            stride: (2, 2),
6865            padding: (0, 0),
6866            output_padding: (0, 0),
6867            dilation: (1, 1),
6868            groups: 1,
6869            training: false,
6870        };
6871
6872        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
6873        let output = conv.forward(&input).unwrap();
6874        assert_eq!(output.shape(), &[1, 1, 4, 4]);
6875
6876        #[rustfmt::skip]
6877        let expected = [
6878            1.0, 1.0, 2.0, 2.0,
6879            1.0, 1.0, 2.0, 2.0,
6880            3.0, 3.0, 4.0, 4.0,
6881            3.0, 3.0, 4.0, 4.0,
6882        ];
6883        assert_close(output.data().unwrap(), &expected, 1e-5);
6884    }
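
    // Illustrative sketch: the comment in the test above derives the result
    // through stride insertion plus an internal convolution. An equivalent
    // way to check the same numbers is the scatter-add view, where each input
    // pixel stamps a scaled copy of the kernel into the output at offset
    // (i * stride, j * stride). `conv_transpose2d_scatter` is a hypothetical
    // single-channel reference (no padding, no output_padding, dilation 1),
    // not the crate's implementation. It assumes `h, w, kh, kw >= 1`.
    fn conv_transpose2d_scatter(
        input: &[f32],
        (h, w): (usize, usize),
        kernel: &[f32],
        (kh, kw): (usize, usize),
        stride: usize,
    ) -> Vec<f32> {
        let h_out = (h - 1) * stride + kh;
        let w_out = (w - 1) * stride + kw;
        let mut out = vec![0.0f32; h_out * w_out];
        for i in 0..h {
            for j in 0..w {
                let x = input[i * w + j];
                for p in 0..kh {
                    for q in 0..kw {
                        // Overlapping stamps accumulate, hence `+=`.
                        out[(i * stride + p) * w_out + (j * stride + q)] += x * kernel[p * kw + q];
                    }
                }
            }
        }
        out
    }

    #[test]
    fn test_conv_transpose2d_scatter_sketch() {
        // Same input, kernel, and stride as the test above.
        let out = conv_transpose2d_scatter(&[1.0, 2.0, 3.0, 4.0], (2, 2), &[1.0; 4], (2, 2), 2);
        #[rustfmt::skip]
        let expected = vec![
            1.0, 1.0, 2.0, 2.0,
            1.0, 1.0, 2.0, 2.0,
            3.0, 3.0, 4.0, 4.0,
            3.0, 3.0, 4.0, 4.0,
        ];
        assert_eq!(out, expected);
    }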
6885
6886    // -----------------------------------------------------------------------
6887    // ConvTranspose2d: bias
6888    // -----------------------------------------------------------------------
6889
6890    #[test]
6891    fn test_conv_transpose2d_bias() {
6892        let weight_data = vec![1.0f32]; // [1, 1, 1, 1] identity
6893        let bias_data = vec![5.0f32];
6894        let conv = ConvTranspose2d {
6895            weight: Parameter::from_slice(&weight_data, &[1, 1, 1, 1]).unwrap(),
6896            bias: Some(Parameter::from_slice(&bias_data, &[1]).unwrap()),
6897            in_channels: 1,
6898            out_channels: 1,
6899            kernel_size: (1, 1),
6900            stride: (1, 1),
6901            padding: (0, 0),
6902            output_padding: (0, 0),
6903            dilation: (1, 1),
6904            groups: 1,
6905            training: false,
6906        };
6907
6908        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
6909        let output = conv.forward(&input).unwrap();
6910        assert_close(output.data().unwrap(), &[6.0, 7.0, 8.0, 9.0], 1e-5);
6911    }
6912
6913    // -----------------------------------------------------------------------
6914    // ConvTranspose2d: edge cases
6915    // -----------------------------------------------------------------------
6916
6917    #[test]
6918    fn test_conv_transpose2d_invalid_ndim() {
6919        let conv =
6920            ConvTranspose2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
6921        // Rank 3 `(C, H, W)` is now a VALID unbatched input (#1609); rank 2 is
6922        // not a recognised ConvTranspose2d input shape (neither batched rank 4
6923        // nor unbatched rank 3), so it must error.
6924        let input = t(&[1.0, 2.0, 3.0], &[1, 3]);
6925        assert!(conv.forward(&input).is_err());
6926    }
6927
6928    #[test]
6929    fn test_conv_transpose2d_channel_mismatch() {
6930        let conv =
6931            ConvTranspose2d::<f32>::new(3, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
6932        let input = t(&[0.0; 5 * 5], &[1, 1, 5, 5]);
6933        assert!(conv.forward(&input).is_err());
6934    }
6935
6936    #[test]
6937    fn test_conv_transpose2d_zero_channels_rejected() {
6938        assert!(ConvTranspose2d::<f32>::new(0, 1, (3, 3), (1, 1), (0, 0), (0, 0), false).is_err());
6939        assert!(ConvTranspose2d::<f32>::new(1, 0, (3, 3), (1, 1), (0, 0), (0, 0), false).is_err());
6940    }
6941
6942    #[test]
6943    fn test_conv_transpose2d_output_padding_too_large() {
6944        // output_padding must be < stride
6945        assert!(ConvTranspose2d::<f32>::new(1, 1, (3, 3), (2, 2), (0, 0), (2, 2), false).is_err());
6946    }
6947
6948    #[test]
6949    fn test_conv_transpose2d_parameter_count() {
6950        let conv =
6951            ConvTranspose2d::<f32>::new(8, 16, (3, 3), (2, 2), (1, 1), (0, 0), true).unwrap();
6952        // weight: 8 * 16 * 3 * 3 = 1152, bias: 16, total: 1168
6953        assert_eq!(conv.num_parameters(), 1168);
6954        assert_eq!(conv.parameters().len(), 2);
6955    }
6956
6957    // ===================================================================
6958    // Conv3d tests
6959    // ===================================================================
6960
6961    // -----------------------------------------------------------------------
6962    // Conv3d: output shape
6963    // -----------------------------------------------------------------------
6964
6965    #[test]
6966    fn test_conv3d_output_shape_no_padding() {
6967        // Input: [1, 1, 5, 5, 5], kernel 3x3x3, stride 1, padding 0
6968        // D_out = (5 - 3) / 1 + 1 = 3
6969        let conv = Conv3d::<f32>::new(1, 4, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).unwrap();
6970        let input = t(&vec![0.0; 5 * 5 * 5], &[1, 1, 5, 5, 5]);
6971        let output = conv.forward(&input).unwrap();
6972        assert_eq!(output.shape(), &[1, 4, 3, 3, 3]);
6973    }
6974
6975    #[test]
6976    fn test_conv3d_output_shape_with_padding() {
6977        // Input: [2, 3, 8, 8, 8], kernel 3x3x3, stride 1, padding 1
6978        // D_out = (8 + 2 - 3) / 1 + 1 = 8
6979        let conv = Conv3d::<f32>::new(3, 16, (3, 3, 3), (1, 1, 1), (1, 1, 1), true).unwrap();
6980        let input = t(&vec![0.0; 2 * 3 * 8 * 8 * 8], &[2, 3, 8, 8, 8]);
6981        let output = conv.forward(&input).unwrap();
6982        assert_eq!(output.shape(), &[2, 16, 8, 8, 8]);
6983    }
6984
6985    #[test]
6986    fn test_conv3d_output_shape_with_stride() {
6987        // Input: [1, 1, 6, 6, 6], kernel 3x3x3, stride 2, padding 0
6988        // D_out = floor((6 - 3) / 2) + 1 = 1 + 1 = 2
6989        let conv = Conv3d::<f32>::new(1, 4, (3, 3, 3), (2, 2, 2), (0, 0, 0), false).unwrap();
6990        let input = t(&vec![0.0; 6 * 6 * 6], &[1, 1, 6, 6, 6]);
6991        let output = conv.forward(&input).unwrap();
6992        assert_eq!(output.shape(), &[1, 4, 2, 2, 2]);
6993    }
6994
6995    // -----------------------------------------------------------------------
6996    // Conv3d: 1x1x1 kernel correctness
6997    // -----------------------------------------------------------------------
6998
6999    #[test]
7000    fn test_conv3d_1x1x1_kernel_correctness() {
7001        // weight: [2, 1, 1, 1, 1] = [3.0, 5.0]
7002        // input:  [1, 1, 2, 1, 1] = [1.0, 2.0]
7003        // output: [1, 2, 2, 1, 1]
7004        //   out_ch 0: [3.0, 6.0]
7005        //   out_ch 1: [5.0, 10.0]
7006        let weight_data = vec![3.0f32, 5.0];
7007        let conv = Conv3d {
7008            weight: Parameter::from_slice(&weight_data, &[2, 1, 1, 1, 1]).unwrap(),
7009            bias: None,
7010            in_channels: 1,
7011            out_channels: 2,
7012            kernel_size: (1, 1, 1),
7013            stride: (1, 1, 1),
7014            padding: (0, 0, 0),
7015            dilation: (1, 1, 1),
7016            groups: 1,
7017            padding_mode: crate::padding::PaddingMode::Zeros,
7018            string_padding: None,
7019            training: false,
7020        };
7021
7022        let input = t(&[1.0, 2.0], &[1, 1, 2, 1, 1]);
7023        let output = conv.forward(&input).unwrap();
7024        assert_eq!(output.shape(), &[1, 2, 2, 1, 1]);
7025        assert_close(output.data().unwrap(), &[3.0, 6.0, 5.0, 10.0], 1e-5);
7026    }
7027
7028    // -----------------------------------------------------------------------
7029    // Conv3d: forward correctness with a 3x3x3 kernel
7030    // -----------------------------------------------------------------------
7031
7032    #[test]
7033    fn test_conv3d_3x3x3_kernel_forward() {
7034        // Input: [1, 1, 3, 3, 3] (all ones), kernel: [1, 1, 3, 3, 3] (all ones)
7035        // Output: [1, 1, 1, 1, 1] = sum of 27 ones = 27.0
7036        let input_data = vec![1.0f32; 27];
7037        let weight_data = vec![1.0f32; 27];
7038        let conv = Conv3d {
7039            weight: Parameter::from_slice(&weight_data, &[1, 1, 3, 3, 3]).unwrap(),
7040            bias: None,
7041            in_channels: 1,
7042            out_channels: 1,
7043            kernel_size: (3, 3, 3),
7044            stride: (1, 1, 1),
7045            padding: (0, 0, 0),
7046            dilation: (1, 1, 1),
7047            groups: 1,
7048            padding_mode: crate::padding::PaddingMode::Zeros,
7049            string_padding: None,
7050            training: false,
7051        };
7052
7053        let input = t(&input_data, &[1, 1, 3, 3, 3]);
7054        let output = conv.forward(&input).unwrap();
7055        assert_eq!(output.shape(), &[1, 1, 1, 1, 1]);
7056        assert_close(output.data().unwrap(), &[27.0], 1e-5);
7057    }
7058
7059    // -----------------------------------------------------------------------
7060    // Conv3d: bias
7061    // -----------------------------------------------------------------------
7062
7063    #[test]
7064    fn test_conv3d_bias() {
7065        let conv = Conv3d {
7066            weight: Parameter::from_slice(&[1.0f32], &[1, 1, 1, 1, 1]).unwrap(),
7067            bias: Some(Parameter::from_slice(&[10.0f32], &[1]).unwrap()),
7068            in_channels: 1,
7069            out_channels: 1,
7070            kernel_size: (1, 1, 1),
7071            stride: (1, 1, 1),
7072            padding: (0, 0, 0),
7073            dilation: (1, 1, 1),
7074            groups: 1,
7075            padding_mode: crate::padding::PaddingMode::Zeros,
7076            string_padding: None,
7077            training: false,
7078        };
7079
7080        let input = t(&[2.0, 3.0], &[1, 1, 2, 1, 1]);
7081        let output = conv.forward(&input).unwrap();
7082        assert_close(output.data().unwrap(), &[12.0, 13.0], 1e-5);
7083    }
7084
7085    // -----------------------------------------------------------------------
7086    // Conv3d: backward produces correct shapes
7087    // -----------------------------------------------------------------------
7088
7089    #[test]
7090    fn test_conv3d_backward_produces_correct_shapes() {
7091        let weight_data = vec![1.0f32; 2 * 3 * 3 * 3]; // [2, 1, 3, 3, 3]
7092        let input_data = vec![1.0f32; 5 * 5 * 5]; // [1, 1, 5, 5, 5]
7093        let bias_data = vec![0.0f32; 2];
7094
7095        let conv = Conv3d {
7096            weight: Parameter::from_slice(&weight_data, &[2, 1, 3, 3, 3]).unwrap(),
7097            bias: Some(Parameter::from_slice(&bias_data, &[2]).unwrap()),
7098            in_channels: 1,
7099            out_channels: 2,
7100            kernel_size: (3, 3, 3),
7101            stride: (1, 1, 1),
7102            padding: (0, 0, 0),
7103            dilation: (1, 1, 1),
7104            groups: 1,
7105            padding_mode: crate::padding::PaddingMode::Zeros,
7106            string_padding: None,
7107            training: false,
7108        };
7109
7110        let input = leaf(&input_data, &[1, 1, 5, 5, 5]);
7111        let output = conv.forward(&input).unwrap();
7112        assert_eq!(output.shape(), &[1, 2, 3, 3, 3]);
7113        assert!(output.grad_fn().is_some());
7114        assert_eq!(output.grad_fn().unwrap().name(), "Conv3dBackward");
7115
7116        let grad_output = t(&vec![1.0; 2 * 3 * 3 * 3], &[1, 2, 3, 3, 3]);
7117        let grads = output.grad_fn().unwrap().backward(&grad_output).unwrap();
7118
7119        assert!(grads[0].is_some());
7120        assert_eq!(grads[0].as_ref().unwrap().shape(), &[1, 1, 5, 5, 5]);
7121        assert!(grads[1].is_some());
7122        assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 1, 3, 3, 3]);
7123        assert!(grads[2].is_some());
7124        assert_eq!(grads[2].as_ref().unwrap().shape(), &[2]);
7125    }
7126
7127    // -----------------------------------------------------------------------
7128    // Conv3d: edge cases
7129    // -----------------------------------------------------------------------
7130
7131    #[test]
7132    fn test_conv3d_invalid_ndim() {
7133        let conv = Conv3d::<f32>::new(1, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).unwrap();
7134        let input = t(&[0.0; 25], &[1, 1, 5, 5]);
7135        assert!(conv.forward(&input).is_err());
7136    }
7137
7138    #[test]
7139    fn test_conv3d_channel_mismatch() {
7140        let conv = Conv3d::<f32>::new(3, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).unwrap();
7141        let input = t(&vec![0.0; 5 * 5 * 5], &[1, 1, 5, 5, 5]);
7142        assert!(conv.forward(&input).is_err());
7143    }
7144
7145    #[test]
7146    fn test_conv3d_zero_channels_rejected() {
7147        assert!(Conv3d::<f32>::new(0, 16, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).is_err());
7148        assert!(Conv3d::<f32>::new(3, 0, (3, 3, 3), (1, 1, 1), (0, 0, 0), false).is_err());
7149    }
7150
7151    #[test]
7152    fn test_conv3d_zero_kernel_rejected() {
7153        assert!(Conv3d::<f32>::new(1, 1, (0, 3, 3), (1, 1, 1), (0, 0, 0), false).is_err());
7154    }
7155
7156    #[test]
7157    fn test_conv3d_zero_stride_rejected() {
7158        assert!(Conv3d::<f32>::new(1, 1, (3, 3, 3), (0, 1, 1), (0, 0, 0), false).is_err());
7159    }
7160
7161    #[test]
7162    fn test_conv3d_parameter_count() {
7163        let conv = Conv3d::<f32>::new(3, 8, (3, 3, 3), (1, 1, 1), (0, 0, 0), true).unwrap();
7164        // weight: 8 * 3 * 3 * 3 * 3 = 648, bias: 8, total: 656
7165        assert_eq!(conv.num_parameters(), 656);
7166        assert_eq!(conv.parameters().len(), 2);
7167    }
7168
7169    #[test]
7170    fn test_conv3d_named_parameters() {
7171        let conv = Conv3d::<f32>::new(1, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), true).unwrap();
7172        let named = conv.named_parameters();
7173        assert_eq!(named.len(), 2);
7174        assert_eq!(named[0].0, "weight");
7175        assert_eq!(named[1].0, "bias");
7176    }
7177
7178    // ===================================================================
7179    // ConvTranspose1d tests
7180    // ===================================================================
7181
7182    // -----------------------------------------------------------------------
7183    // ConvTranspose1d: output shape
7184    // -----------------------------------------------------------------------
7185
7186    #[test]
7187    fn test_conv_transpose1d_output_shape_basic() {
7188        // Input: [1, 1, 5], kernel 3, stride 1, padding 0, output_padding 0
7189        // L_out = (5 - 1) * 1 - 0 + 3 + 0 = 7
7190        let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 1, 0, 0, false).unwrap();
7191        let input = t(&[0.0; 5], &[1, 1, 5]);
7192        let output = conv.forward(&input).unwrap();
7193        assert_eq!(output.shape(), &[1, 1, 7]);
7194    }
7195
7196    #[test]
7197    fn test_conv_transpose1d_output_shape_stride2() {
7198        // Input: [1, 1, 3], kernel 3, stride 2, padding 0, output_padding 0
7199        // L_out = (3 - 1) * 2 - 0 + 3 + 0 = 7
7200        let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 2, 0, 0, false).unwrap();
7201        let input = t(&[0.0; 3], &[1, 1, 3]);
7202        let output = conv.forward(&input).unwrap();
7203        assert_eq!(output.shape(), &[1, 1, 7]);
7204    }
7205
7206    #[test]
7207    fn test_conv_transpose1d_output_shape_with_padding() {
7208        // Input: [1, 1, 5], kernel 3, stride 2, padding 1, output_padding 0
7209        // L_out = (5 - 1) * 2 - 2 * 1 + 3 + 0 = 9
7210        let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 2, 1, 0, false).unwrap();
7211        let input = t(&[0.0; 5], &[1, 1, 5]);
7212        let output = conv.forward(&input).unwrap();
7213        assert_eq!(output.shape(), &[1, 1, 9]);
7214    }
7215
7216    #[test]
7217    fn test_conv_transpose1d_output_shape_with_output_padding() {
7218        // Input: [1, 1, 5], kernel 3, stride 2, padding 1, output_padding 1
7219        // L_out = (5 - 1) * 2 - 2 * 1 + 3 + 1 = 10
7220        let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 2, 1, 1, false).unwrap();
7221        let input = t(&[0.0; 5], &[1, 1, 5]);
7222        let output = conv.forward(&input).unwrap();
7223        assert_eq!(output.shape(), &[1, 1, 10]);
7224    }
7225
7226    // -----------------------------------------------------------------------
7227    // ConvTranspose1d: kernel_size=1 correctness
7228    // -----------------------------------------------------------------------
7229
7230    #[test]
7231    fn test_conv_transpose1d_1x1_kernel() {
7232        // With kernel_size=1, stride 1, and no padding, the transposed conv is
7233        // a per-position linear transform whose weight has its channel axes swapped.
7234        // weight shape: [1, 2, 1] (in_channels=1, out_channels=2, k=1)
7235        let weight_data = vec![3.0f32, 7.0]; // [1, 2, 1]
7236        let conv = ConvTranspose1d {
7237            weight: Parameter::from_slice(&weight_data, &[1, 2, 1]).unwrap(),
7238            bias: None,
7239            in_channels: 1,
7240            out_channels: 2,
7241            kernel_size: 1,
7242            stride: 1,
7243            padding: 0,
7244            output_padding: 0,
7245            dilation: 1,
7246            groups: 1,
7247            training: false,
7248        };
7249
7250        let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
7251        let output = conv.forward(&input).unwrap();
7252        assert_eq!(output.shape(), &[1, 2, 3]);
7253
7254        // out_ch 0: input * 3 = [3, 6, 9]
7255        // out_ch 1: input * 7 = [7, 14, 21]
7256        assert_close(
7257            output.data().unwrap(),
7258            &[3.0, 6.0, 9.0, 7.0, 14.0, 21.0],
7259            1e-5,
7260        );
7261    }
7262
7263    // -----------------------------------------------------------------------
7264    // ConvTranspose1d: stride=2 correctness
7265    // -----------------------------------------------------------------------
7266
7267    #[test]
7268    fn test_conv_transpose1d_stride2_correctness() {
7269        // Input: [1, 1, 2] = [1, 2]
7270        // Kernel: [1, 1, 2] = [1, 1] (all ones)
7271        // Stride=2, padding=0, output_padding=0
7272        // L_out = (2-1)*2 + 2 = 4
7273        //
7274        // Stride insertion produces [1, 0, 2]
7275        // Flipped kernel (all ones): [1, 1]
7276        // Internal conv with pad = 2-1 = 1, stride=1 on [1, 0, 2]:
7277        // Padded to [0, 1, 0, 2, 0]
7278        // Convolve with [1, 1] kernel, output 4:
7279        //   pos 0: 0+1 = 1
7280        //   pos 1: 1+0 = 1
7281        //   pos 2: 0+2 = 2
7282        //   pos 3: 2+0 = 2
7283        let weight_data = vec![1.0f32; 2]; // [1, 1, 2]
7284        let conv = ConvTranspose1d {
7285            weight: Parameter::from_slice(&weight_data, &[1, 1, 2]).unwrap(),
7286            bias: None,
7287            in_channels: 1,
7288            out_channels: 1,
7289            kernel_size: 2,
7290            stride: 2,
7291            padding: 0,
7292            output_padding: 0,
7293            dilation: 1,
7294            groups: 1,
7295            training: false,
7296        };
7297
7298        let input = t(&[1.0, 2.0], &[1, 1, 2]);
7299        let output = conv.forward(&input).unwrap();
7300        assert_eq!(output.shape(), &[1, 1, 4]);
7301        assert_close(output.data().unwrap(), &[1.0, 1.0, 2.0, 2.0], 1e-5);
7302    }
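
    // Illustrative sketch of the three steps walked through in the comment
    // above: stride insertion, zero padding by `kernel - 1`, then a stride-1
    // cross-correlation with the flipped kernel.
    // `conv_transpose1d_via_insertion` is a hypothetical single-channel
    // reference (padding 0, output_padding 0, dilation 1), not the crate's
    // implementation. It assumes a non-empty input and kernel and `stride >= 1`.
    fn conv_transpose1d_via_insertion(input: &[f32], kernel: &[f32], stride: usize) -> Vec<f32> {
        let k = kernel.len();
        // Steps 1 and 2: write input[i] at offset (k - 1) + i * stride, which
        // leaves `stride - 1` zeros between neighbours and `k - 1` zeros on
        // each side.
        let inserted_len = (input.len() - 1) * stride + 1;
        let mut padded = vec![0.0f32; inserted_len + 2 * (k - 1)];
        for (i, x) in input.iter().enumerate() {
            padded[(k - 1) + i * stride] = *x;
        }
        // Step 3: slide the flipped kernel over the padded signal.
        let flipped: Vec<f32> = kernel.iter().rev().copied().collect();
        padded
            .windows(k)
            .map(|win| win.iter().zip(&flipped).map(|(a, b)| a * b).sum::<f32>())
            .collect()
    }

    #[test]
    fn test_conv_transpose1d_via_insertion_sketch() {
        // Same input, kernel, and stride as the test above.
        let out = conv_transpose1d_via_insertion(&[1.0, 2.0], &[1.0, 1.0], 2);
        assert_eq!(out, vec![1.0, 1.0, 2.0, 2.0]);
        // An asymmetric kernel shows that the flip matters.
        let out = conv_transpose1d_via_insertion(&[1.0, 2.0], &[1.0, 2.0], 1);
        assert_eq!(out, vec![1.0, 4.0, 4.0]);
    }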

    // -----------------------------------------------------------------------
    // ConvTranspose1d: bias
    // -----------------------------------------------------------------------

    #[test]
    fn test_conv_transpose1d_bias() {
        let conv = ConvTranspose1d {
            weight: Parameter::from_slice(&[1.0f32], &[1, 1, 1]).unwrap(),
            bias: Some(Parameter::from_slice(&[5.0f32], &[1]).unwrap()),
            in_channels: 1,
            out_channels: 1,
            kernel_size: 1,
            stride: 1,
            padding: 0,
            output_padding: 0,
            dilation: 1,
            groups: 1,
            training: false,
        };

        let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
        let output = conv.forward(&input).unwrap();
        assert_close(output.data().unwrap(), &[6.0, 7.0, 8.0], 1e-5);
    }

    // -----------------------------------------------------------------------
    // ConvTranspose1d: backward produces gradients
    // -----------------------------------------------------------------------

    #[test]
    fn test_conv_transpose1d_backward_produces_gradients() {
        let weight_data = vec![1.0f32; 3]; // [1, 1, 3]
        let bias_data = vec![0.0f32; 1];

        let conv = ConvTranspose1d {
            weight: Parameter::from_slice(&weight_data, &[1, 1, 3]).unwrap(),
            bias: Some(Parameter::from_slice(&bias_data, &[1]).unwrap()),
            in_channels: 1,
            out_channels: 1,
            kernel_size: 3,
            stride: 1,
            padding: 0,
            output_padding: 0,
            dilation: 1,
            groups: 1,
            training: false,
        };

        let input = leaf(&[1.0f32, 2.0, 3.0], &[1, 1, 3]);
        let output = conv.forward(&input).unwrap();
        // L_out = (3 - 1) * 1 - 0 + 3 + 0 = 5
        assert_eq!(output.shape(), &[1, 1, 5]);
        assert!(output.grad_fn().is_some());
        assert_eq!(output.grad_fn().unwrap().name(), "ConvTranspose1dBackward");

        let grad_output = t(&[1.0; 5], &[1, 1, 5]);
        let grads = output.grad_fn().unwrap().backward(&grad_output).unwrap();

        // grad_input shape: [1, 1, 3]
        assert!(grads[0].is_some());
        assert_eq!(grads[0].as_ref().unwrap().shape(), &[1, 1, 3]);
        // grad_weight shape: [1, 1, 3]
        assert!(grads[1].is_some());
        assert_eq!(grads[1].as_ref().unwrap().shape(), &[1, 1, 3]);
        // grad_bias shape: [1]
        assert!(grads[2].is_some());
        assert_eq!(grads[2].as_ref().unwrap().shape(), &[1]);
    }

    // -----------------------------------------------------------------------
    // ConvTranspose1d: edge cases
    // -----------------------------------------------------------------------

    #[test]
    fn test_conv_transpose1d_invalid_ndim() {
        let conv = ConvTranspose1d::<f32>::new(1, 1, 3, 1, 0, 0, false).unwrap();
        let input = t(&[0.0; 4], &[1, 1, 2, 2]);
        assert!(conv.forward(&input).is_err());
    }

    #[test]
    fn test_conv_transpose1d_channel_mismatch() {
        let conv = ConvTranspose1d::<f32>::new(3, 1, 3, 1, 0, 0, false).unwrap();
        let input = t(&[0.0; 10], &[1, 1, 10]);
        assert!(conv.forward(&input).is_err());
    }

    #[test]
    fn test_conv_transpose1d_zero_channels_rejected() {
        assert!(ConvTranspose1d::<f32>::new(0, 1, 3, 1, 0, 0, false).is_err());
        assert!(ConvTranspose1d::<f32>::new(1, 0, 3, 1, 0, 0, false).is_err());
    }

    #[test]
    fn test_conv_transpose1d_output_padding_too_large() {
        assert!(ConvTranspose1d::<f32>::new(1, 1, 3, 2, 0, 2, false).is_err());
    }

    #[test]
    fn test_conv_transpose1d_parameter_count() {
        let conv = ConvTranspose1d::<f32>::new(8, 16, 5, 2, 1, 0, true).unwrap();
        // weight: 8 * 16 * 5 = 640, bias: 16, total: 656
        assert_eq!(conv.num_parameters(), 656);
        assert_eq!(conv.parameters().len(), 2);
    }

    // ===================================================================
    // ConvTranspose3d tests
    // ===================================================================

    // -----------------------------------------------------------------------
    // ConvTranspose3d: output shape
    // -----------------------------------------------------------------------

    #[test]
    fn test_conv_transpose3d_output_shape_basic() {
        // Input: [1, 1, 3, 3, 3], kernel 3x3x3, stride 1, padding 0, output_padding 0
        // D_out = (3 - 1) * 1 - 0 + 3 + 0 = 5
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
                .unwrap();
        let input = t(&[0.0; 27], &[1, 1, 3, 3, 3]);
        let output = conv.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 5, 5, 5]);
    }

    #[test]
    fn test_conv_transpose3d_output_shape_stride2() {
        // Input: [1, 1, 2, 2, 2], kernel 3x3x3, stride 2, padding 0, output_padding 0
        // D_out = (2 - 1) * 2 - 0 + 3 + 0 = 5
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (2, 2, 2), (0, 0, 0), (0, 0, 0), false)
                .unwrap();
        let input = t(&[0.0; 8], &[1, 1, 2, 2, 2]);
        let output = conv.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 5, 5, 5]);
    }

    #[test]
    fn test_conv_transpose3d_output_shape_with_padding() {
        // Input: [1, 1, 3, 3, 3], kernel 3x3x3, stride 2, padding 1, output_padding 0
        // D_out = (3 - 1) * 2 - 2 + 3 + 0 = 5
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0), false)
                .unwrap();
        let input = t(&[0.0; 27], &[1, 1, 3, 3, 3]);
        let output = conv.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 5, 5, 5]);
    }

    #[test]
    fn test_conv_transpose3d_output_shape_with_output_padding() {
        // Input: [1, 1, 3, 3, 3], kernel 3x3x3, stride 2, padding 1, output_padding 1
        // D_out = (3 - 1) * 2 - 2 + 3 + 1 = 6
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (2, 2, 2), (1, 1, 1), (1, 1, 1), false)
                .unwrap();
        let input = t(&[0.0; 27], &[1, 1, 3, 3, 3]);
        let output = conv.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 6, 6, 6]);
    }

    // -----------------------------------------------------------------------
    // ConvTranspose3d: stride=2 upsamples (doubles spatial dims)
    // -----------------------------------------------------------------------

    #[test]
    fn test_conv_transpose3d_stride2_upsamples() {
        // With stride=2, kernel=2x2x2, padding=0, output_padding=0:
        // D_out = (D - 1) * 2 + 2 = 2 * D
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (2, 2, 2), (2, 2, 2), (0, 0, 0), (0, 0, 0), false)
                .unwrap();
        let input = t(&vec![0.0; 4 * 4 * 4], &[1, 1, 4, 4, 4]);
        let output = conv.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 8, 8, 8]);
    }
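
    // Illustrative sketch, not part of the original suite: the `D_out`
    // arithmetic in the shape-test comments above follows the usual
    // transposed-convolution size formula,
    //   out = (in - 1) * stride - 2 * padding + dilation * (k - 1) + output_padding + 1.
    // The helper name `sketch_conv_transpose_out_len` is an assumption made for
    // this example; it only restates that formula for one spatial dimension and
    // assumes `in_len >= 1`, `k >= 1`, and a non-negative resulting length.
    fn sketch_conv_transpose_out_len(
        in_len: usize,
        k: usize,
        stride: usize,
        padding: usize,
        output_padding: usize,
        dilation: usize,
    ) -> usize {
        // Add every positive term before subtracting `2 * padding`, so the
        // unsigned intermediate stays as large as possible.
        (in_len - 1) * stride + dilation * (k - 1) + output_padding + 1 - 2 * padding
    }

    #[test]
    fn sketch_conv_transpose_out_len_matches_shape_tests() {
        assert_eq!(sketch_conv_transpose_out_len(3, 3, 1, 0, 0, 1), 5); // basic
        assert_eq!(sketch_conv_transpose_out_len(2, 3, 2, 0, 0, 1), 5); // stride 2
        assert_eq!(sketch_conv_transpose_out_len(3, 3, 2, 1, 0, 1), 5); // padding 1
        assert_eq!(sketch_conv_transpose_out_len(3, 3, 2, 1, 1, 1), 6); // output_padding 1
        assert_eq!(sketch_conv_transpose_out_len(4, 2, 2, 0, 0, 1), 8); // stride-2 upsample
    }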

    // -----------------------------------------------------------------------
    // ConvTranspose3d: 1x1x1 kernel correctness
    // -----------------------------------------------------------------------

    #[test]
    fn test_conv_transpose3d_1x1x1_kernel() {
        // weight shape: [in=1, out=2, 1, 1, 1]
        let weight_data = vec![3.0f32, 7.0]; // [1, 2, 1, 1, 1]
        let conv = ConvTranspose3d {
            weight: Parameter::from_slice(&weight_data, &[1, 2, 1, 1, 1]).unwrap(),
            bias: None,
            in_channels: 1,
            out_channels: 2,
            kernel_size: (1, 1, 1),
            stride: (1, 1, 1),
            padding: (0, 0, 0),
            output_padding: (0, 0, 0),
            dilation: (1, 1, 1),
            groups: 1,
            training: false,
        };

        let input = t(&[1.0, 2.0], &[1, 1, 2, 1, 1]);
        let output = conv.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 2, 2, 1, 1]);
        assert_close(output.data().unwrap(), &[3.0, 6.0, 7.0, 14.0], 1e-5);
    }

    // -----------------------------------------------------------------------
    // ConvTranspose3d: bias
    // -----------------------------------------------------------------------

    #[test]
    fn test_conv_transpose3d_bias() {
        let conv = ConvTranspose3d {
            weight: Parameter::from_slice(&[1.0f32], &[1, 1, 1, 1, 1]).unwrap(),
            bias: Some(Parameter::from_slice(&[5.0f32], &[1]).unwrap()),
            in_channels: 1,
            out_channels: 1,
            kernel_size: (1, 1, 1),
            stride: (1, 1, 1),
            padding: (0, 0, 0),
            output_padding: (0, 0, 0),
            dilation: (1, 1, 1),
            groups: 1,
            training: false,
        };

        let input = t(&[1.0, 2.0], &[1, 1, 2, 1, 1]);
        let output = conv.forward(&input).unwrap();
        assert_close(output.data().unwrap(), &[6.0, 7.0], 1e-5);
    }

    // -----------------------------------------------------------------------
    // ConvTranspose3d: backward produces gradients
    // -----------------------------------------------------------------------

    #[test]
    fn test_conv_transpose3d_backward_produces_gradients() {
        let weight_data = vec![1.0f32; 2 * 2 * 2]; // [1, 1, 2, 2, 2]
        let bias_data = vec![0.0f32; 1];

        let conv = ConvTranspose3d {
            weight: Parameter::from_slice(&weight_data, &[1, 1, 2, 2, 2]).unwrap(),
            bias: Some(Parameter::from_slice(&bias_data, &[1]).unwrap()),
            in_channels: 1,
            out_channels: 1,
            kernel_size: (2, 2, 2),
            stride: (1, 1, 1),
            padding: (0, 0, 0),
            output_padding: (0, 0, 0),
            dilation: (1, 1, 1),
            groups: 1,
            training: false,
        };

        // D_out = (2-1)*1 - 0 + 2 + 0 = 3
        let input = leaf(&[1.0f32; 8], &[1, 1, 2, 2, 2]);
        let output = conv.forward(&input).unwrap();
        assert_eq!(output.shape(), &[1, 1, 3, 3, 3]);
        assert!(output.grad_fn().is_some());
        assert_eq!(output.grad_fn().unwrap().name(), "ConvTranspose3dBackward");

        let grad_output = t(&[1.0; 27], &[1, 1, 3, 3, 3]);
        let grads = output.grad_fn().unwrap().backward(&grad_output).unwrap();

        assert!(grads[0].is_some());
        assert_eq!(grads[0].as_ref().unwrap().shape(), &[1, 1, 2, 2, 2]);
        assert!(grads[1].is_some());
        assert_eq!(grads[1].as_ref().unwrap().shape(), &[1, 1, 2, 2, 2]);
        assert!(grads[2].is_some());
        assert_eq!(grads[2].as_ref().unwrap().shape(), &[1]);
    }

    // -----------------------------------------------------------------------
    // ConvTranspose3d: edge cases
    // -----------------------------------------------------------------------

    #[test]
    fn test_conv_transpose3d_invalid_ndim() {
        let conv =
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
                .unwrap();
        // Rank 4 `(C, D, H, W)` is now a VALID unbatched input (#1609); rank 3
        // is neither batched (rank 5) nor unbatched (rank 4), so it must error.
        let input = t(&[0.0; 25], &[1, 5, 5]);
        assert!(conv.forward(&input).is_err());
    }

    #[test]
    fn test_conv_transpose3d_channel_mismatch() {
        let conv =
            ConvTranspose3d::<f32>::new(3, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
                .unwrap();
        let input = t(&vec![0.0; 5 * 5 * 5], &[1, 1, 5, 5, 5]);
        assert!(conv.forward(&input).is_err());
    }

    #[test]
    fn test_conv_transpose3d_zero_channels_rejected() {
        assert!(
            ConvTranspose3d::<f32>::new(0, 1, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
                .is_err()
        );
        assert!(
            ConvTranspose3d::<f32>::new(1, 0, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
                .is_err()
        );
    }

    #[test]
    fn test_conv_transpose3d_output_padding_too_large() {
        assert!(
            ConvTranspose3d::<f32>::new(1, 1, (3, 3, 3), (2, 2, 2), (0, 0, 0), (2, 2, 2), false)
                .is_err()
        );
    }

    #[test]
    fn test_conv_transpose3d_parameter_count() {
        let conv =
            ConvTranspose3d::<f32>::new(8, 16, (3, 3, 3), (2, 2, 2), (1, 1, 1), (0, 0, 0), true)
                .unwrap();
        // weight: 8 * 16 * 3 * 3 * 3 = 3456, bias: 16, total: 3472
        assert_eq!(conv.num_parameters(), 3472);
        assert_eq!(conv.parameters().len(), 2);
    }

    // =======================================================================
    // ConvTranspose groups (#1607) + dilation (#1608) + unbatched (#1609)
    //
    // All expected values are derived from a LIVE PyTorch 2.11.0 oracle
    // (R-CHAR-3): `torch.nn.functional.conv_transpose{1,2,3}d(...)` forward
    // outputs and `x.grad` / `w.grad` / `b.grad` after `y.sum().backward()`
    // (grad_output = ones), with the exact deterministic weights/inputs
    // reproduced below. The transposed weight layout is `[in, out/groups, *k]`
    // (`torch/nn/modules/conv.py:164`); grad_weight comes back in that same
    // `[in, out/groups, *k]` layout (verified against the oracle). The per-group
    // partition mirrors `aten/src/ATen/native/Convolution.cpp:1723-1729`. No
    // expected value is produced by the code under test (no tautological
    // self-reference). The oracle script lives in the commit body.
    // =======================================================================

    /// Build a grouped/dilated ConvTranspose1d through the production
    /// `new_full` constructor, overwriting weight/bias with caller-supplied
    /// deterministic tensors. Weight must be `[in, out/groups, k]`.
    #[allow(clippy::too_many_arguments)]
    fn ct1d_full_fixed(
        in_c: usize,
        out_c: usize,
        k: usize,
        stride: usize,
        padding: usize,
        output_padding: usize,
        dilation: usize,
        groups: usize,
        weight: &[f32],
        bias: Option<&[f32]>,
    ) -> ConvTranspose1d<f32> {
        let mut ct = ConvTranspose1d::<f32>::new_full(
            in_c,
            out_c,
            k,
            stride,
            padding,
            output_padding,
            dilation,
            groups,
            bias.is_some(),
        )
        .unwrap();
        ct.weight = Parameter::from_slice(weight, &[in_c, out_c / groups, k]).unwrap();
        if let Some(bvals) = bias {
            ct.bias = Some(Parameter::from_slice(bvals, &[out_c]).unwrap());
        }
        ct
    }

    /// Grouped ConvTranspose1d, groups=2. Forward + grad_x/grad_w/grad_b match
    /// the live torch 2.11 oracle. in=4 out=4 k=2 groups=2.
    #[test]
    fn test_conv_transpose1d_groups2_matches_torch() {
        let weight: Vec<f32> = (1..=16).map(|i| i as f32 * 0.1).collect(); // [4,2,2]
        let bias = [0.5f32, -0.5, 0.25, -0.25];
        let ct = ct1d_full_fixed(4, 4, 2, 1, 0, 0, 1, 2, &weight, Some(&bias));
        let x = leaf(&(1..=20).map(|i| i as f32).collect::<Vec<_>>(), &[1, 4, 5]);
        let y = ct.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 4, 6]);
        assert_close(
            y.data().unwrap(),
            &[
                3.6, 8.0, 9.4, 10.8, 12.2, 7.5, 4.0, 10.2, 12.4, 14.6, 16.8, 9.5, 30.95, 66.55,
                71.15, 75.75, 80.35, 43.25, 35.85, 77.25, 82.65, 88.05, 93.45, 49.75,
            ],
            1e-3,
        );
        let grads = ct
            .forward(&x)
            .unwrap()
            .grad_fn()
            .unwrap()
            .backward(&t(&[1.0f32; 24], &[1, 4, 6]))
            .unwrap();
        assert_close(
            grads[0].as_ref().unwrap().data().unwrap(),
            &[
                1.0, 1.0, 1.0, 1.0, 1.0, 2.6, 2.6, 2.6, 2.6, 2.6, 4.2, 4.2, 4.2, 4.2, 4.2, 5.8,
                5.8, 5.8, 5.8, 5.8,
            ],
            1e-4,
        );
        assert_eq!(grads[1].as_ref().unwrap().shape(), &[4, 2, 2]);
        assert_close(
            grads[1].as_ref().unwrap().data().unwrap(),
            &[
                15.0, 15.0, 15.0, 15.0, 40.0, 40.0, 40.0, 40.0, 65.0, 65.0, 65.0, 65.0, 90.0, 90.0,
                90.0, 90.0,
            ],
            1e-4,
        );
        assert_close(
            grads[2].as_ref().unwrap().data().unwrap(),
            &[6.0, 6.0, 6.0, 6.0],
            1e-4,
        );
    }
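
    // Illustrative sketch, not part of the original suite: why the oracle
    // gradients in the test above are constant within each channel. With
    // `grad_output = ones`, batch size 1, and `padding = 0`, no kernel tap is
    // cropped, so d(sum y)/d x[ic, i] is the sum of `w[ic, .., ..]` and
    // d(sum y)/d w[ic, oc, j] is the sum of `x[0, ic, ..]`. The helper name
    // `sketch_per_channel_sums` is an assumption made for this example; it
    // treats a flat buffer as `channels` equally sized contiguous rows.
    fn sketch_per_channel_sums(flat: &[f32], channels: usize) -> Vec<f32> {
        assert!(channels > 0 && !flat.is_empty() && flat.len() % channels == 0);
        flat.chunks(flat.len() / channels)
            .map(|row| row.iter().sum::<f32>())
            .collect()
    }

    #[test]
    fn sketch_per_channel_sums_match_groups2_oracle() {
        let weight: Vec<f32> = (1..=16).map(|i| i as f32 * 0.1).collect(); // [4,2,2]
        let x: Vec<f32> = (1..=20).map(|i| i as f32).collect(); // [1,4,5]
        // grad_x per input channel: 1.0, 2.6, 4.2, 5.8 (as asserted above).
        let gx = sketch_per_channel_sums(&weight, 4);
        for (got, want) in gx.iter().zip([1.0f32, 2.6, 4.2, 5.8]) {
            assert!((got - want).abs() < 1e-4);
        }
        // grad_w per input channel: 15, 40, 65, 90 (as asserted above).
        assert_eq!(sketch_per_channel_sums(&x, 4), vec![15.0, 40.0, 65.0, 90.0]);
    }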

    /// Depthwise ConvTranspose1d, groups=3, no bias. in=3 out=3 k=2.
    #[test]
    fn test_conv_transpose1d_groups3_depthwise_matches_torch() {
        let weight: Vec<f32> = (1..=6).map(|i| i as f32 * 0.5).collect(); // [3,1,2]
        let ct = ct1d_full_fixed(3, 3, 2, 1, 0, 0, 1, 3, &weight, None);
        let x = leaf(&(1..=15).map(|i| i as f32).collect::<Vec<_>>(), &[1, 3, 5]);
        let y = ct.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 3, 6]);
        assert_close(
            y.data().unwrap(),
            &[
                0.5, 2.0, 3.5, 5.0, 6.5, 5.0, 9.0, 22.5, 26.0, 29.5, 33.0, 20.0, 27.5, 63.0, 68.5,
                74.0, 79.5, 45.0,
            ],
            1e-3,
        );
        let grads = ct
            .forward(&x)
            .unwrap()
            .grad_fn()
            .unwrap()
            .backward(&t(&[1.0f32; 18], &[1, 3, 6]))
            .unwrap();
        assert_close(
            grads[0].as_ref().unwrap().data().unwrap(),
            &[
                1.5, 1.5, 1.5, 1.5, 1.5, 3.5, 3.5, 3.5, 3.5, 3.5, 5.5, 5.5, 5.5, 5.5, 5.5,
            ],
            1e-4,
        );
        assert_eq!(grads[1].as_ref().unwrap().shape(), &[3, 1, 2]);
        assert_close(
            grads[1].as_ref().unwrap().data().unwrap(),
            &[15.0, 15.0, 40.0, 40.0, 65.0, 65.0],
            1e-4,
        );
    }

    /// Dilated ConvTranspose1d, dilation=2, groups=1. in=2 out=2 k=3.
    #[test]
    fn test_conv_transpose1d_dilation2_matches_torch() {
        let weight: Vec<f32> = (1..=12).map(|i| i as f32 * 0.1).collect(); // [2,2,3]
        let bias = [1.0f32, -1.0];
        let ct = ct1d_full_fixed(2, 2, 3, 1, 0, 0, 2, 1, &weight, Some(&bias));
        let x = leaf(&(1..=8).map(|i| i as f32).collect::<Vec<_>>(), &[1, 2, 4]);
        let y = ct.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 2, 8]);
        assert_close(
            y.data().unwrap(),
            &[
                4.6, 5.4, 10.4, 12.2, 12.0, 14.2, 8.2, 9.4, 4.4, 5.8, 13.2, 16.2, 14.8, 18.2, 9.2,
                11.0,
            ],
            1e-3,
        );
        let grads = ct
            .forward(&x)
            .unwrap()
            .grad_fn()
            .unwrap()
            .backward(&t(&[1.0f32; 16], &[1, 2, 8]))
            .unwrap();
        assert_close(
            grads[0].as_ref().unwrap().data().unwrap(),
            &[2.1, 2.1, 2.1, 2.1, 5.7, 5.7, 5.7, 5.7],
            1e-4,
        );
        assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 2, 3]);
        assert_close(
            grads[1].as_ref().unwrap().data().unwrap(),
            &[
                10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 26.0, 26.0, 26.0, 26.0, 26.0, 26.0,
            ],
            1e-4,
        );
        assert_close(
            grads[2].as_ref().unwrap().data().unwrap(),
            &[8.0, 8.0],
            1e-4,
        );
    }

    /// ConvTranspose1d combo: groups=2, dilation=2, stride=2, padding=1,
    /// output_padding=1. in=4 out=2 k=2. Forward + all grads vs torch oracle.
    #[test]
    fn test_conv_transpose1d_combo_matches_torch() {
        let weight: Vec<f32> = (1..=8).map(|i| i as f32 * 0.1).collect(); // [4,1,2]
        let bias = [0.5f32, -0.5];
        let ct = ct1d_full_fixed(4, 2, 2, 2, 1, 1, 2, 2, &weight, Some(&bias));
        let x = leaf(&(1..=12).map(|i| i as f32).collect::<Vec<_>>(), &[1, 4, 3]);
        let y = ct.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 2, 6]);
        assert_close(
            y.data().unwrap(),
            &[
                0.5, 4.0, 0.5, 5.0, 0.5, 3.5, -0.5, 23.4, -0.5, 26.0, -0.5, 14.5,
            ],
            1e-3,
        );
        let grads = ct
            .forward(&x)
            .unwrap()
            .grad_fn()
            .unwrap()
            .backward(&t(&[1.0f32; 12], &[1, 2, 6]))
            .unwrap();
        assert_close(
            grads[0].as_ref().unwrap().data().unwrap(),
            &[0.2, 0.3, 0.3, 0.4, 0.7, 0.7, 0.6, 1.1, 1.1, 0.8, 1.5, 1.5],
            1e-4,
        );
        assert_eq!(grads[1].as_ref().unwrap().shape(), &[4, 1, 2]);
        assert_close(
            grads[1].as_ref().unwrap().data().unwrap(),
            &[5.0, 6.0, 11.0, 15.0, 17.0, 24.0, 23.0, 33.0],
            1e-4,
        );
        assert_close(
            grads[2].as_ref().unwrap().data().unwrap(),
            &[6.0, 6.0],
            1e-4,
        );
    }

    /// Unbatched ConvTranspose1d input `(C, L)`. Forward emits rank-2 output;
    /// backward routes grad to the unbatched shape. Closes #1609.
    #[test]
    fn test_conv_transpose1d_unbatched_matches_torch() {
        let weight: Vec<f32> = (1..=12).map(|i| i as f32 * 0.1).collect(); // [2,3,2]
        let bias = [0.5f32, -0.5, 0.25];
        let ct = ct1d_full_fixed(2, 3, 2, 1, 0, 0, 1, 1, &weight, Some(&bias));
        let x = leaf(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), &[2, 3]); // (C=2, L=3)
        let y = ct.forward(&x).unwrap();
        assert_eq!(
            y.shape(),
            &[3, 4],
            "unbatched output must be rank 2 (C_out, L_out)"
        );
        assert_close(
            y.data().unwrap(),
            &[
                3.4, 7.6, 9.4, 5.9, 3.4, 9.0, 11.6, 6.7, 5.15, 12.15, 15.55, 9.25,
            ],
            1e-3,
        );
        // y.sum().backward(): grad_output is all-ones (matches the torch oracle);
        // full autograd so the grad flows back through SqueezeBackward to the
        // unbatched leaf, not just the inner ConvTranspose grad_fn.
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let gx = x.grad().unwrap().expect("input grad must be populated");
        assert_eq!(gx.shape(), &[2, 3], "grad must match unbatched input shape");
        assert_close(gx.data().unwrap(), &[2.1, 2.1, 2.1, 5.7, 5.7, 5.7], 1e-4);
    }

    /// `ConvTranspose1d::new_full` rejects `groups` not dividing channels,
    /// matching `_ConvNd.__init__` ValueError (`conv.py:105-110`).
    #[test]
    fn test_conv_transpose1d_groups_must_divide_channels() {
        assert!(ConvTranspose1d::<f32>::new_full(3, 4, 2, 1, 0, 0, 1, 2, true).is_err());
        assert!(ConvTranspose1d::<f32>::new_full(4, 5, 2, 1, 0, 0, 1, 2, true).is_err());
    }
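
    // Illustrative sketch, not part of the original suite: the divisibility
    // rule exercised by the test above, combined with the transposed weight
    // layout `[in, out/groups, k]` used by the `ct*d_full_fixed` helpers. The
    // name `sketch_transposed_weight_shape` and its `Option` return type are
    // assumptions made for this example; the production check lives in
    // `new_full` and may report errors differently.
    fn sketch_transposed_weight_shape(
        in_c: usize,
        out_c: usize,
        groups: usize,
        k: usize,
    ) -> Option<[usize; 3]> {
        // `groups` must be non-zero and divide both channel counts.
        if groups == 0 || in_c % groups != 0 || out_c % groups != 0 {
            return None;
        }
        Some([in_c, out_c / groups, k])
    }

    #[test]
    fn sketch_transposed_weight_shape_follows_group_rule() {
        assert_eq!(sketch_transposed_weight_shape(3, 4, 2, 2), None); // in % groups != 0
        assert_eq!(sketch_transposed_weight_shape(4, 5, 2, 2), None); // out % groups != 0
        assert_eq!(sketch_transposed_weight_shape(4, 4, 2, 2), Some([4, 2, 2]));
        assert_eq!(sketch_transposed_weight_shape(3, 3, 3, 2), Some([3, 1, 2])); // depthwise
    }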

    // ----- ConvTranspose2d -----

    /// Build a grouped/dilated ConvTranspose2d via `new_full`, overwriting
    /// weight/bias. Weight must be `[in, out/groups, kH, kW]`.
    #[allow(clippy::too_many_arguments)]
    fn ct2d_full_fixed(
        in_c: usize,
        out_c: usize,
        k: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
        output_padding: (usize, usize),
        dilation: (usize, usize),
        groups: usize,
        weight: &[f32],
        bias: Option<&[f32]>,
    ) -> ConvTranspose2d<f32> {
        let mut ct = ConvTranspose2d::<f32>::new_full(
            in_c,
            out_c,
            k,
            stride,
            padding,
            output_padding,
            dilation,
            groups,
            bias.is_some(),
        )
        .unwrap();
        ct.weight = Parameter::from_slice(weight, &[in_c, out_c / groups, k.0, k.1]).unwrap();
        if let Some(bvals) = bias {
            ct.bias = Some(Parameter::from_slice(bvals, &[out_c]).unwrap());
        }
        ct
    }

    /// Grouped ConvTranspose2d, groups=2. in=4 out=2 k=(2,2).
    #[test]
    fn test_conv_transpose2d_groups2_matches_torch() {
        let weight: Vec<f32> = (1..=16).map(|i| i as f32 * 0.1).collect(); // [4,1,2,2]
        let bias = [0.5f32, -0.5];
        let ct = ct2d_full_fixed(
            4,
            2,
            (2, 2),
            (1, 1),
            (0, 0),
            (0, 0),
            (1, 1),
            2,
            &weight,
            Some(&bias),
        );
        let x = leaf(
            &(1..=16).map(|i| i as f32).collect::<Vec<_>>(),
            &[1, 4, 2, 2],
        );
        let y = ct.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 2, 3, 3]);
        assert_close(
            y.data().unwrap(),
            &[
                3.1, 6.9, 4.5, 8.1, 18.9, 11.7, 6.3, 14.1, 8.5, 24.5, 53.9, 29.1, 58.3, 126.7,
                68.3, 34.1, 73.9, 39.5,
            ],
            1e-3,
        );
        let grads = ct
            .forward(&x)
            .unwrap()
            .grad_fn()
            .unwrap()
            .backward(&t(&[1.0f32; 18], &[1, 2, 3, 3]))
            .unwrap();
        assert_close(
            grads[0].as_ref().unwrap().data().unwrap(),
            &[
                1.0, 1.0, 1.0, 1.0, 2.6, 2.6, 2.6, 2.6, 4.2, 4.2, 4.2, 4.2, 5.8, 5.8, 5.8, 5.8,
            ],
            1e-4,
        );
        assert_eq!(grads[1].as_ref().unwrap().shape(), &[4, 1, 2, 2]);
        assert_close(
            grads[1].as_ref().unwrap().data().unwrap(),
            &[
                10.0, 10.0, 10.0, 10.0, 26.0, 26.0, 26.0, 26.0, 42.0, 42.0, 42.0, 42.0, 58.0, 58.0,
                58.0, 58.0,
            ],
            1e-4,
        );
        assert_close(
            grads[2].as_ref().unwrap().data().unwrap(),
            &[9.0, 9.0],
            1e-4,
        );
    }

    /// Dilated ConvTranspose2d, dilation=2, no bias. in=1 out=1 k=(2,2).
    #[test]
    fn test_conv_transpose2d_dilation2_matches_torch() {
        let weight: Vec<f32> = (1..=4).map(|i| i as f32 * 0.1).collect(); // [1,1,2,2]
        let ct = ct2d_full_fixed(
            1,
            1,
            (2, 2),
            (1, 1),
            (0, 0),
            (0, 0),
            (2, 2),
            1,
            &weight,
            None,
        );
        let x = leaf(
            &(1..=9).map(|i| i as f32).collect::<Vec<_>>(),
            &[1, 1, 3, 3],
        );
        let y = ct.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 1, 5, 5]);
        assert_close(
            y.data().unwrap(),
            &[
                0.1, 0.2, 0.5, 0.4, 0.6, 0.4, 0.5, 1.4, 1.0, 1.2, 1.0, 1.4, 3.6, 2.4, 3.0, 1.2,
                1.5, 3.4, 2.0, 2.4, 2.1, 2.4, 5.5, 3.2, 3.6,
            ],
            1e-3,
        );
        let grads = ct
            .forward(&x)
            .unwrap()
            .grad_fn()
            .unwrap()
            .backward(&t(&[1.0f32; 25], &[1, 1, 5, 5]))
            .unwrap();
        assert_close(
            grads[0].as_ref().unwrap().data().unwrap(),
            &[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
            1e-4,
        );
        assert_eq!(grads[1].as_ref().unwrap().shape(), &[1, 1, 2, 2]);
        assert_close(
            grads[1].as_ref().unwrap().data().unwrap(),
            &[45.0, 45.0, 45.0, 45.0],
            1e-4,
        );
    }

    /// ConvTranspose2d combo: groups=2, dilation=2, stride=2, padding=1,
    /// output_padding=1. in=2 out=2 k=(2,2).
    #[test]
    fn test_conv_transpose2d_combo_matches_torch() {
        let weight: Vec<f32> = (1..=8).map(|i| i as f32 * 0.1).collect(); // [2,1,2,2]
        let bias = [0.25f32, -0.25];
        let ct = ct2d_full_fixed(
            2,
            2,
            (2, 2),
            (2, 2),
            (1, 1),
            (1, 1),
            (2, 2),
            2,
            &weight,
            Some(&bias),
        );
        let x = leaf(
            &(1..=8).map(|i| i as f32).collect::<Vec<_>>(),
            &[1, 2, 2, 2],
        );
        let y = ct.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 2, 4, 4]);
        assert_close(
            y.data().unwrap(),
            &[
                0.25, 0.25, 0.25, 0.25, 0.25, 2.25, 0.25, 1.85, 0.25, 0.25, 0.25, 0.25, 0.25, 2.65,
                0.25, 1.85, -0.25, -0.25, -0.25, -0.25, -0.25, 16.15, -0.25, 9.35, -0.25, -0.25,
                -0.25, -0.25, -0.25, 10.95, -0.25, 6.15,
            ],
            1e-3,
        );
        let grads = ct
            .forward(&x)
            .unwrap()
            .grad_fn()
            .unwrap()
            .backward(&t(&[1.0f32; 32], &[1, 2, 4, 4]))
            .unwrap();
        assert_close(
            grads[0].as_ref().unwrap().data().unwrap(),
            &[0.4, 0.7, 0.6, 1.0, 0.8, 1.5, 1.4, 2.6],
            1e-4,
        );
        assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 1, 2, 2]);
        assert_close(
            grads[1].as_ref().unwrap().data().unwrap(),
            &[4.0, 7.0, 6.0, 10.0, 8.0, 15.0, 14.0, 26.0],
            1e-4,
        );
        assert_close(
            grads[2].as_ref().unwrap().data().unwrap(),
            &[16.0, 16.0],
            1e-4,
        );
    }

    /// Unbatched ConvTranspose2d input `(C, H, W)`. Closes #1609.
    #[test]
    fn test_conv_transpose2d_unbatched_matches_torch() {
        let weight: Vec<f32> = (1..=8).map(|i| i as f32 * 0.1).collect(); // [2,1,2,2]
        let bias = [0.5f32];
        let ct = ct2d_full_fixed(
            2,
            1,
            (2, 2),
            (1, 1),
            (0, 0),
            (0, 0),
            (1, 1),
            1,
            &weight,
            Some(&bias),
        );
        let x = leaf(&(1..=8).map(|i| i as f32).collect::<Vec<_>>(), &[2, 2, 2]); // (C=2,H=2,W=2)
        let y = ct.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 3, 3], "unbatched output must be rank 3");
        assert_close(
            y.data().unwrap(),
            &[3.1, 6.9, 4.5, 8.1, 18.9, 11.7, 6.3, 14.1, 8.5],
            1e-3,
        );
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let gx = x.grad().unwrap().expect("input grad must be populated");
        assert_eq!(
            gx.shape(),
            &[2, 2, 2],
            "grad must match unbatched input shape"
        );
        assert_close(
            gx.data().unwrap(),
            &[1.0, 1.0, 1.0, 1.0, 2.6, 2.6, 2.6, 2.6],
            1e-4,
        );
    }
8135
8136    // ----- ConvTranspose3d -----
8137
8138    /// Build a grouped/dilated ConvTranspose3d via `new_full`, overwriting
8139    /// weight/bias. Weight must be `[in, out/groups, kD, kH, kW]`.
8140    #[allow(clippy::too_many_arguments)]
8141    fn ct3d_full_fixed(
8142        in_c: usize,
8143        out_c: usize,
8144        k: (usize, usize, usize),
8145        stride: (usize, usize, usize),
8146        padding: (usize, usize, usize),
8147        output_padding: (usize, usize, usize),
8148        dilation: (usize, usize, usize),
8149        groups: usize,
8150        weight: &[f32],
8151        bias: Option<&[f32]>,
8152    ) -> ConvTranspose3d<f32> {
8153        let mut ct = ConvTranspose3d::<f32>::new_full(
8154            in_c,
8155            out_c,
8156            k,
8157            stride,
8158            padding,
8159            output_padding,
8160            dilation,
8161            groups,
8162            bias.is_some(),
8163        )
8164        .unwrap();
8165        ct.weight = Parameter::from_slice(weight, &[in_c, out_c / groups, k.0, k.1, k.2]).unwrap();
8166        if let Some(bvals) = bias {
8167            ct.bias = Some(Parameter::from_slice(bvals, &[out_c]).unwrap());
8168        }
8169        ct
8170    }
8171
8172    /// Grouped ConvTranspose3d, groups=2, k=(1,1,1). in=2 out=2.
8173    #[test]
8174    fn test_conv_transpose3d_groups2_matches_torch() {
8175        let weight: Vec<f32> = (1..=2).map(|i| i as f32 * 0.5).collect(); // [2,1,1,1,1]
8176        let bias = [0.5f32, -0.5];
8177        let ct = ct3d_full_fixed(
8178            2,
8179            2,
8180            (1, 1, 1),
8181            (1, 1, 1),
8182            (0, 0, 0),
8183            (0, 0, 0),
8184            (1, 1, 1),
8185            2,
8186            &weight,
8187            Some(&bias),
8188        );
8189        let x = leaf(
8190            &(1..=16).map(|i| i as f32).collect::<Vec<_>>(),
8191            &[1, 2, 2, 2, 2],
8192        );
8193        let y = ct.forward(&x).unwrap();
8194        assert_eq!(y.shape(), &[1, 2, 2, 2, 2]);
8195        assert_close(
8196            y.data().unwrap(),
8197            &[
8198                1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 8.5, 9.5, 10.5, 11.5, 12.5, 13.5, 14.5,
8199                15.5,
8200            ],
8201            1e-3,
8202        );
8203        let grads = ct
8204            .forward(&x)
8205            .unwrap()
8206            .grad_fn()
8207            .unwrap()
8208            .backward(&t(&[1.0f32; 16], &[1, 2, 2, 2, 2]))
8209            .unwrap();
8210        assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 1, 1, 1, 1]);
8211        assert_close(
8212            grads[1].as_ref().unwrap().data().unwrap(),
8213            &[36.0, 100.0],
8214            1e-4,
8215        );
8216        assert_close(
8217            grads[2].as_ref().unwrap().data().unwrap(),
8218            &[8.0, 8.0],
8219            1e-4,
8220        );
8221    }
8222
8223    /// Dilated ConvTranspose3d, dilation=2, no bias. in=1 out=1 k=(2,2,2).
8224    #[test]
8225    fn test_conv_transpose3d_dilation2_matches_torch() {
8226        let weight: Vec<f32> = (1..=8).map(|i| i as f32 * 0.1).collect(); // [1,1,2,2,2]
8227        let ct = ct3d_full_fixed(
8228            1,
8229            1,
8230            (2, 2, 2),
8231            (1, 1, 1),
8232            (0, 0, 0),
8233            (0, 0, 0),
8234            (2, 2, 2),
8235            1,
8236            &weight,
8237            None,
8238        );
8239        let x = leaf(
8240            &(1..=8).map(|i| i as f32).collect::<Vec<_>>(),
8241            &[1, 1, 2, 2, 2],
8242        );
8243        let y = ct.forward(&x).unwrap();
8244        assert_eq!(y.shape(), &[1, 1, 4, 4, 4]);
8245        // Spot-check a representative slab against the torch oracle.
8246        let yd = y.data().unwrap();
8247        assert_close(&yd[0..8], &[0.1, 0.2, 0.2, 0.4, 0.3, 0.4, 0.6, 0.8], 1e-3);
8248        assert_close(&yd[56..64], &[3.5, 4.2, 4.0, 4.8, 4.9, 5.6, 5.6, 6.4], 1e-3);
8249        let grads = ct
8250            .forward(&x)
8251            .unwrap()
8252            .grad_fn()
8253            .unwrap()
8254            .backward(&t(&[1.0f32; 64], &[1, 1, 4, 4, 4]))
8255            .unwrap();
8256        assert_close(
8257            grads[0].as_ref().unwrap().data().unwrap(),
8258            &[3.6, 3.6, 3.6, 3.6, 3.6, 3.6, 3.6, 3.6],
8259            1e-4,
8260        );
8261        assert_eq!(grads[1].as_ref().unwrap().shape(), &[1, 1, 2, 2, 2]);
8262        assert_close(
8263            grads[1].as_ref().unwrap().data().unwrap(),
8264            &[36.0, 36.0, 36.0, 36.0, 36.0, 36.0, 36.0, 36.0],
8265            1e-4,
8266        );
8267    }
8268
8269    /// ConvTranspose3d combo: groups=2, stride=2, output_padding=1. in=2 out=2
8270    /// k=(2,2,2). Forward + all grads vs torch oracle.
8271    #[test]
8272    fn test_conv_transpose3d_combo_matches_torch() {
8273        let weight: Vec<f32> = (1..=16).map(|i| i as f32 * 0.05).collect(); // [2,1,2,2,2]
8274        let bias = [0.1f32, -0.1];
8275        let ct = ct3d_full_fixed(
8276            2,
8277            2,
8278            (2, 2, 2),
8279            (2, 2, 2),
8280            (0, 0, 0),
8281            (1, 1, 1),
8282            (1, 1, 1),
8283            2,
8284            &weight,
8285            Some(&bias),
8286        );
8287        let x = leaf(
8288            &(1..=2).map(|i| i as f32).collect::<Vec<_>>(),
8289            &[1, 2, 1, 1, 1],
8290        );
8291        let y = ct.forward(&x).unwrap();
8292        assert_eq!(y.shape(), &[1, 2, 3, 3, 3]);
8293        let yd = y.data().unwrap();
        // Spot-check the first depth slab (9 values) of each output channel
        // vs the torch oracle.
8295        assert_close(
8296            &yd[0..9],
8297            &[0.15, 0.2, 0.1, 0.25, 0.3, 0.1, 0.1, 0.1, 0.1],
8298            1e-3,
8299        );
8300        assert_close(
8301            &yd[27..36],
8302            &[0.8, 0.9, -0.1, 1.0, 1.1, -0.1, -0.1, -0.1, -0.1],
8303            1e-3,
8304        );
8305        let grads = ct
8306            .forward(&x)
8307            .unwrap()
8308            .grad_fn()
8309            .unwrap()
8310            .backward(&t(&[1.0f32; 54], &[1, 2, 3, 3, 3]))
8311            .unwrap();
8312        assert_close(
8313            grads[0].as_ref().unwrap().data().unwrap(),
8314            &[1.8, 5.0],
8315            1e-4,
8316        );
8317        assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 1, 2, 2, 2]);
8318        assert_close(
8319            grads[1].as_ref().unwrap().data().unwrap(),
8320            &[
8321                1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
8322            ],
8323            1e-4,
8324        );
8325        assert_close(
8326            grads[2].as_ref().unwrap().data().unwrap(),
8327            &[27.0, 27.0],
8328            1e-4,
8329        );
8330    }
8331
8332    /// ConvTranspose3d dilated forward with `output_padding=1` AND a kernel dim
8333    /// whose `dilation*(k-1) < padding` (here `kw=1`, `dilation_w=2`,
8334    /// `padding_w=1` -> internal pad `eff_kw-1-pw = -1`, NEGATIVE). The prior
8335    /// `internal_pad = eff_k-1-padding` `usize` subtraction wrapped to
8336    /// `usize::MAX` in release builds, which made `im2col_3d_dilated`'s width
8337    /// bounds check reject every position -> ZERO scatter -> the output was the
8338    /// bias alone in the `output_padding`-extended trailing region (#1619:
8339    /// op_db conv_transpose3d sample 4/5, ferrotorch=bias vs torch=-94.2 at
8340    /// flat index 279). The fix crops the upsampled signal for the negative
8341    /// dims and zero-pads the rest, matching upstream's output-extent-bounded
8342    /// `col2vol` scatter (`aten/src/ATen/native/vol2col.h:80-106`). Oracle is
8343    /// live torch 2.11.0 `F.conv_transpose3d(stride=2, padding=1,
8344    /// output_padding=1, dilation=(2,3,2))`. Closes #1619.
8345    #[test]
8346    fn test_conv_transpose3d_dilated_output_padding_negative_internal_pad_matches_torch() {
8347        let weight: Vec<f32> = (1..=4).map(|i| i as f32 * 0.1).collect(); // [1,1,2,2,1]
8348        let bias = [0.5f32];
8349        let ct = ct3d_full_fixed(
8350            1,
8351            1,
8352            (2, 2, 1),
8353            (2, 2, 2),
8354            (1, 1, 1),
8355            (1, 1, 1),
8356            (2, 3, 2),
8357            1,
8358            &weight,
8359            Some(&bias),
8360        );
8361        let x = leaf(
8362            &(1..=8).map(|i| i as f32).collect::<Vec<_>>(),
8363            &[1, 1, 2, 2, 2],
8364        );
8365        let y = ct.forward(&x).unwrap();
8366        assert_eq!(y.shape(), &[1, 1, 4, 5, 2]);
        // Full output vs the live torch 2.11.0 oracle. The positions that
        // receive a scatter contribution (indices 13, 15, 19, 33, 35, 39) are
        // exactly the ones the bug left at the bias value 0.5; index 39 (the
        // trailing corner) must be 3.7, not the bias.
8370        let yd = y.data().unwrap();
8371        #[rustfmt::skip]
8372        let oracle = [
8373            0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 2.5, 0.5, 2.5,
8374            0.5, 0.5, 0.5, 3.7, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
8375            0.5, 2.9, 0.5, 2.9, 0.5, 0.5, 0.5, 3.7,
8376        ];
8377        assert_close(yd, &oracle, 1e-4);
8378        // Backward must also flow through the cropped path (torch oracle grads).
8379        let grads = ct
8380            .forward(&x)
8381            .unwrap()
8382            .grad_fn()
8383            .unwrap()
8384            .backward(&t(&[1.0f32; 40], &[1, 1, 4, 5, 2]))
8385            .unwrap();
8386        assert_close(
8387            grads[0].as_ref().unwrap().data().unwrap(),
8388            &[0.0, 0.4, 0.0, 0.7, 0.0, 0.6, 0.0, 1.0],
8389            1e-4,
8390        );
8391        assert_close(
8392            grads[1].as_ref().unwrap().data().unwrap(),
8393            &[8.0, 14.0, 12.0, 20.0],
8394            1e-4,
8395        );
8396        assert_close(grads[2].as_ref().unwrap().data().unwrap(), &[40.0], 1e-4);
8397    }
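
    // Illustrative sketch, not part of the crate: the `sketch_*` helpers below
    // are hypothetical names that only spell out the arithmetic behind the
    // #1619 fixture above. They assume the usual transposed-convolution size
    // rule `out = (in-1)*stride - 2*padding + dilation*(k-1) + output_padding + 1`
    // and the signed form of `internal_pad = dilation*(k-1) - padding`.

    /// Output length along one spatial dim, or `None` if the configuration
    /// would underflow (for example `in == 0` or too much padding).
    fn sketch_conv_transpose_out_len(
        n: usize,
        k: usize,
        stride: usize,
        padding: usize,
        output_padding: usize,
        dilation: usize,
    ) -> Option<usize> {
        let grown = n.checked_sub(1)? * stride + dilation * k.checked_sub(1)? + output_padding + 1;
        grown.checked_sub(2 * padding)
    }

    /// Internal pad of the equivalent direct convolution, kept signed so a
    /// negative value means "crop" rather than wrapping around.
    fn sketch_internal_pad(k: usize, padding: usize, dilation: usize) -> isize {
        dilation as isize * (k as isize - 1) - padding as isize
    }

    #[test]
    fn sketch_negative_internal_pad_arithmetic() {
        // Fixture above: input 2x2x2, k=(2,2,1), stride=2, padding=1,
        // output_padding=1, dilation=(2,3,2) -> output extent (4, 5, 2).
        assert_eq!(sketch_conv_transpose_out_len(2, 2, 2, 1, 1, 2), Some(4));
        assert_eq!(sketch_conv_transpose_out_len(2, 2, 2, 1, 1, 3), Some(5));
        assert_eq!(sketch_conv_transpose_out_len(2, 1, 2, 1, 1, 2), Some(2));
        // Only the width dim (k=1) goes negative.
        assert_eq!(sketch_internal_pad(2, 1, 2), 1);
        assert_eq!(sketch_internal_pad(2, 1, 3), 2);
        assert_eq!(sketch_internal_pad(1, 1, 2), -1);
        // The same subtraction in wrapping `usize` arithmetic yields the huge
        // value described in the doc comment above.
        assert_eq!(0usize.wrapping_sub(1), usize::MAX);
    }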
8398
8399    /// Unbatched ConvTranspose3d input `(C, D, H, W)`. Closes #1609.
8400    #[test]
8401    fn test_conv_transpose3d_unbatched_matches_torch() {
8402        let weight: Vec<f32> = (1..=2).map(|i| i as f32 * 0.5).collect(); // [2,1,1,1,1]
8403        let bias = [1.0f32];
8404        let ct = ct3d_full_fixed(
8405            2,
8406            1,
8407            (1, 1, 1),
8408            (1, 1, 1),
8409            (0, 0, 0),
8410            (0, 0, 0),
8411            (1, 1, 1),
8412            1,
8413            &weight,
8414            Some(&bias),
8415        );
8416        let x = leaf(
8417            &(1..=16).map(|i| i as f32).collect::<Vec<_>>(),
8418            &[2, 2, 2, 2],
8419        ); // (C=2,D=2,H=2,W=2)
8420        let y = ct.forward(&x).unwrap();
8421        assert_eq!(y.shape(), &[1, 2, 2, 2], "unbatched output must be rank 4");
        // torch oracle forward (only the shape is asserted here): with
        // w=[0.5, 1.0], groups=1 and bias 1.0,
        // y[0,d,h,w] = 0.5*x[0,d,h,w] + 1.0*x[1,d,h,w] + 1.0.
8424        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
8425        ferrotorch_core::backward(&sum).unwrap();
8426        let gx = x.grad().unwrap().expect("input grad must be populated");
8427        assert_eq!(
8428            gx.shape(),
8429            &[2, 2, 2, 2],
8430            "grad must match unbatched input shape"
8431        );
8432        // grad_x = sum over out of weight = ch0: 0.5, ch1: 1.0 (1x1x1 kernel).
8433        assert_close(
8434            gx.data().unwrap(),
8435            &[
8436                0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
8437            ],
8438            1e-4,
8439        );
8440    }
8441
    /// `ConvTranspose3d::new_full` and `ConvTranspose2d::new_full` reject a
    /// `groups` value that does not divide both channel counts.
8443    #[test]
8444    fn test_conv_transpose3d_groups_must_divide_channels() {
8445        assert!(
8446            ConvTranspose3d::<f32>::new_full(
8447                3,
8448                4,
8449                (1, 1, 1),
8450                (1, 1, 1),
8451                (0, 0, 0),
8452                (0, 0, 0),
8453                (1, 1, 1),
8454                2,
8455                true
8456            )
8457            .is_err()
8458        );
8459        assert!(
8460            ConvTranspose2d::<f32>::new_full(4, 5, (1, 1), (1, 1), (0, 0), (0, 0), (1, 1), 2, true)
8461                .is_err()
8462        );
8463    }
8464
8465    // -----------------------------------------------------------------------
8466    // padding_mode threading — closes #1443
8467    //
    // Conv1d / Conv3d honor reflect/replicate/circular padding_mode for both
    // forward AND backward; ConvTranspose{1,2,3}d reject every mode other than
    // `zeros` (matching `_ConvTransposeNd.__init__` ValueError, conv.py:755-758).
8471    //
8472    // All expected values are derived from a live PyTorch 2.11 oracle
8473    // (R-CHAR-3): the exact `torch.nn.Conv{1,3}d(..., padding_mode=...)` forward
8474    // outputs and `x.grad` after `out.sum().backward()`, with the same
8475    // deterministic weights/inputs reproduced below. The oracle script is in
8476    // the #1443 commit body. No tautological self-reference.
8477    // -----------------------------------------------------------------------
8478
8479    /// Build a Conv1d with explicit weight/bias for deterministic oracle parity.
8480    fn conv1d_fixed(
8481        weight: &[f32],
8482        wshape: &[usize],
8483        bias: &[f32],
8484        kernel: usize,
8485        padding: usize,
8486        mode: crate::padding::PaddingMode,
8487    ) -> Conv1d<f32> {
8488        let w = Parameter::from_slice(weight, wshape).unwrap();
8489        let b = Parameter::from_slice(bias, &[wshape[0]]).unwrap();
8490        Conv1d {
8491            weight: w,
8492            bias: Some(b),
8493            in_channels: wshape[1],
8494            out_channels: wshape[0],
8495            kernel_size: kernel,
8496            stride: 1,
8497            padding,
8498            dilation: 1,
8499            groups: 1,
8500            padding_mode: mode,
8501            string_padding: None,
8502            training: false,
8503        }
8504    }
8505
8506    /// Conv1d reflect: forward output matches torch oracle.
8507    /// torch: Conv1d(1,1,3,padding=1,padding_mode='reflect'), w=[1,2,3], b=0.5,
8508    /// x=[1,2,3,4,5] -> out=[10.5, 14.5, 20.5, 26.5, 26.5].
8509    #[test]
8510    fn test_conv1d_reflect_forward_matches_torch() {
8511        let conv = conv1d_fixed(
8512            &[1.0, 2.0, 3.0],
8513            &[1, 1, 3],
8514            &[0.5],
8515            3,
8516            1,
8517            crate::padding::PaddingMode::Reflect,
8518        );
8519        let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
8520        let y = conv.forward(&x).unwrap();
8521        assert_eq!(y.shape(), &[1, 1, 5]);
8522        assert_close(y.data().unwrap(), &[10.5, 14.5, 20.5, 26.5, 26.5], 1e-4);
8523    }
8524
8525    /// Conv1d replicate: forward output matches torch oracle.
8526    /// torch out=[9.5, 14.5, 20.5, 26.5, 29.5].
8527    #[test]
8528    fn test_conv1d_replicate_forward_matches_torch() {
8529        let conv = conv1d_fixed(
8530            &[1.0, 2.0, 3.0],
8531            &[1, 1, 3],
8532            &[0.5],
8533            3,
8534            1,
8535            crate::padding::PaddingMode::Replicate,
8536        );
8537        let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
8538        let y = conv.forward(&x).unwrap();
8539        assert_close(y.data().unwrap(), &[9.5, 14.5, 20.5, 26.5, 29.5], 1e-4);
8540    }
8541
8542    /// Conv1d circular: forward output matches torch oracle.
8543    /// torch out=[13.5, 14.5, 20.5, 26.5, 17.5].
8544    #[test]
8545    fn test_conv1d_circular_forward_matches_torch() {
8546        let conv = conv1d_fixed(
8547            &[1.0, 2.0, 3.0],
8548            &[1, 1, 3],
8549            &[0.5],
8550            3,
8551            1,
8552            crate::padding::PaddingMode::Circular,
8553        );
8554        let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
8555        let y = conv.forward(&x).unwrap();
8556        assert_close(y.data().unwrap(), &[13.5, 14.5, 20.5, 26.5, 17.5], 1e-4);
8557    }
8558
8559    /// Conv1d reflect backward: input gradient flows through the non-zero pad
8560    /// (the #1550 regression class — a pad returning requires_grad=false would
8561    /// sever autograd and produce a None / zero grad here). torch grad_input
8562    /// for out.sum().backward() = [3, 7, 6, 9, 5].
8563    #[test]
8564    fn test_conv1d_reflect_backward_input_grad_matches_torch() {
8565        let conv = conv1d_fixed(
8566            &[1.0, 2.0, 3.0],
8567            &[1, 1, 3],
8568            &[0.5],
8569            3,
8570            1,
8571            crate::padding::PaddingMode::Reflect,
8572        );
8573        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
8574        let y = conv.forward(&x).unwrap();
8575        // grad_fn must be present — the autograd graph survives the pre-pad.
8576        assert!(
8577            y.grad_fn().is_some(),
8578            "Conv1d reflect output lost its grad_fn — pre-pad severed autograd (#1550 class)"
8579        );
8580        // `out.sum().backward()` — matches the torch oracle (grad_output = ones).
8581        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
8582        ferrotorch_core::backward(&sum).unwrap();
8583        let xg = x
8584            .grad()
8585            .unwrap()
8586            .expect("input grad must be populated — pre-pad must be autograd-aware");
8587        assert_close(xg.data().unwrap(), &[3.0, 7.0, 6.0, 9.0, 5.0], 1e-4);
8588    }
8589
8590    /// Conv1d circular backward input grad matches torch: [6, 6, 6, 6, 6].
8591    #[test]
8592    fn test_conv1d_circular_backward_input_grad_matches_torch() {
8593        let conv = conv1d_fixed(
8594            &[1.0, 2.0, 3.0],
8595            &[1, 1, 3],
8596            &[0.5],
8597            3,
8598            1,
8599            crate::padding::PaddingMode::Circular,
8600        );
8601        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
8602        let y = conv.forward(&x).unwrap();
8603        assert!(y.grad_fn().is_some());
8604        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
8605        ferrotorch_core::backward(&sum).unwrap();
8606        let xg = x.grad().unwrap().expect("input grad must be populated");
8607        assert_close(xg.data().unwrap(), &[6.0, 6.0, 6.0, 6.0, 6.0], 1e-4);
8608    }
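
    // Illustrative sketch, not the crate's padding implementation: the
    // `SketchPad` / `sketch_*` names are hypothetical. It shows, under the
    // assumption of stride 1 and a single channel, how each padding mode maps a
    // padded position back to an input index, and how that one index map
    // explains both the forward values and the `out.sum().backward()` input
    // gradients asserted in the Conv1d tests above.
    #[derive(Clone, Copy)]
    enum SketchPad {
        Reflect,
        Replicate,
        Circular,
    }

    /// Input index read by padded position `j` (`p` pad cells per side, input
    /// length `n`). Assumes `p < n`, which reflect padding requires anyway.
    fn sketch_pad_src(mode: SketchPad, j: usize, n: usize, p: usize) -> usize {
        let i = j as isize - p as isize; // position relative to the unpadded input
        let n = n as isize;
        let src = match mode {
            SketchPad::Replicate => i.clamp(0, n - 1),
            SketchPad::Circular => i.rem_euclid(n),
            SketchPad::Reflect if i < 0 => -i,
            SketchPad::Reflect if i >= n => 2 * (n - 1) - i,
            SketchPad::Reflect => i,
        };
        src as usize
    }

    /// Naive padded cross-correlation: `y[o] = b + sum_k w[k] * x[src(o + k)]`.
    fn sketch_conv1d_padded(x: &[f32], w: &[f32], b: f32, p: usize, mode: SketchPad) -> Vec<f32> {
        let n = x.len();
        let out_len = n + 2 * p - w.len() + 1;
        (0..out_len)
            .map(|o| {
                let acc: f32 = w
                    .iter()
                    .enumerate()
                    .map(|(k, wk)| *wk * x[sketch_pad_src(mode, o + k, n, p)])
                    .sum();
                b + acc
            })
            .collect()
    }

    /// Input gradient for a grad_output of ones: every tap scatters its weight
    /// back to the input index it read from.
    fn sketch_conv1d_padded_grad_input(n: usize, w: &[f32], p: usize, mode: SketchPad) -> Vec<f32> {
        let out_len = n + 2 * p - w.len() + 1;
        let mut gx = vec![0.0f32; n];
        for o in 0..out_len {
            for (k, wk) in w.iter().enumerate() {
                gx[sketch_pad_src(mode, o + k, n, p)] += *wk;
            }
        }
        gx
    }

    #[test]
    fn sketch_padding_modes_reproduce_conv1d_oracle_values() {
        let (x, w) = ([1.0f32, 2.0, 3.0, 4.0, 5.0], [1.0f32, 2.0, 3.0]);
        let fwd = |m| sketch_conv1d_padded(&x, &w, 0.5, 1, m);
        assert_eq!(fwd(SketchPad::Reflect), vec![10.5, 14.5, 20.5, 26.5, 26.5]);
        assert_eq!(fwd(SketchPad::Replicate), vec![9.5, 14.5, 20.5, 26.5, 29.5]);
        assert_eq!(fwd(SketchPad::Circular), vec![13.5, 14.5, 20.5, 26.5, 17.5]);
        let gx = |m| sketch_conv1d_padded_grad_input(5, &w, 1, m);
        assert_eq!(gx(SketchPad::Reflect), vec![3.0, 7.0, 6.0, 9.0, 5.0]);
        assert_eq!(gx(SketchPad::Circular), vec![6.0; 5]);
    }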
8609
8610    // -----------------------------------------------------------------------
8611    // Conv1d groups + dilation (closes #1600) — oracle: live torch 2.11.0
8612    // -----------------------------------------------------------------------
8613
    /// Build a grouped/dilated Conv1d through the production `new_full`
    /// constructor, then overwrite the weight/bias with deterministic
    /// caller-supplied values by writing the `weight` / `bias` fields
    /// directly. The weight must be `[out, in/groups, k]` (the grouped-conv
    /// layout, `conv.py:171`).
8618    fn conv1d_full_fixed(
8619        in_c: usize,
8620        out_c: usize,
8621        k: usize,
8622        dilation: usize,
8623        groups: usize,
8624        weight: &[f32],
8625        bias: Option<&[f32]>,
8626    ) -> Conv1d<f32> {
8627        let mut conv =
8628            Conv1d::<f32>::new_full(in_c, out_c, k, 1, 0, dilation, groups, bias.is_some())
8629                .unwrap();
8630        // Overwrite the Kaiming-initialised weight with the deterministic
8631        // tensor (direct field write — tests live in the same module).
8632        conv.weight = Parameter::from_slice(weight, &[out_c, in_c / groups, k]).unwrap();
8633        if let Some(bvals) = bias {
8634            conv.bias = Some(Parameter::from_slice(bvals, &[out_c]).unwrap());
8635        }
8636        conv
8637    }
8638
8639    /// Grouped Conv1d, groups=2. Forward + grad_x + grad_w + grad_b all match
8640    /// the live torch 2.11 oracle (`F.conv1d(x, w, b, groups=2)`,
8641    /// `out.sum().backward()`). in=4 out=4 k=2 groups=2.
8642    #[test]
8643    fn test_conv1d_groups2_forward_and_backward_matches_torch() {
8644        // weight [4, 2, 2] = arange(1..=16) * 0.1; bias [0.5,-0.5,0.25,-0.25].
8645        let weight: Vec<f32> = (1..=16).map(|i| i as f32 * 0.1).collect();
8646        let bias = [0.5f32, -0.5, 0.25, -0.25];
8647        let conv = conv1d_full_fixed(4, 4, 2, 1, 2, &weight, Some(&bias));
8648
8649        // x [1, 4, 5] = arange(1..=20).
8650        let x_data: Vec<f32> = (1..=20).map(|i| i as f32).collect();
8651        let x = leaf(&x_data, &[1, 4, 5]);
8652        let y = conv.forward(&x).unwrap();
8653        assert_eq!(y.shape(), &[1, 4, 4]);
8654        // torch A_fwd:
8655        assert_close(
8656            y.data().unwrap(),
8657            &[
8658                5.6, 6.6, 7.6, 8.6, 11.0, 13.6, 16.2, 18.8, 60.15, 64.35, 68.55, 72.75, 82.05,
8659                87.85, 93.65, 99.45,
8660            ],
8661            1e-3,
8662        );
8663
8664        // out.sum().backward() => grad_output = ones.
8665        let grad_output = t(&[1.0f32; 16], &[1, 4, 4]);
8666        let grads = conv
8667            .forward(&x)
8668            .unwrap()
8669            .grad_fn()
8670            .unwrap()
8671            .backward(&grad_output)
8672            .unwrap();
8673        // grad_input (torch A_gx):
8674        assert_close(
8675            grads[0].as_ref().unwrap().data().unwrap(),
8676            &[
8677                0.6, 1.4, 1.4, 1.4, 0.8, 1.0, 2.2, 2.2, 2.2, 1.2, 2.2, 4.6, 4.6, 4.6, 2.4, 2.6,
8678                5.4, 5.4, 5.4, 2.8,
8679            ],
8680            1e-4,
8681        );
8682        // grad_weight (torch A_gw) — shape [4, 2, 2]:
8683        assert_eq!(grads[1].as_ref().unwrap().shape(), &[4, 2, 2]);
8684        assert_close(
8685            grads[1].as_ref().unwrap().data().unwrap(),
8686            &[
8687                10.0, 14.0, 30.0, 34.0, 10.0, 14.0, 30.0, 34.0, 50.0, 54.0, 70.0, 74.0, 50.0, 54.0,
8688                70.0, 74.0,
8689            ],
8690            1e-4,
8691        );
8692        // grad_bias (torch A_gb):
8693        assert_close(
8694            grads[2].as_ref().unwrap().data().unwrap(),
8695            &[4.0, 4.0, 4.0, 4.0],
8696            1e-4,
8697        );
8698    }
8699
8700    /// Depthwise Conv1d, groups=3 (groups == in_channels), no bias. Forward +
8701    /// grad_x + grad_w match the live torch 2.11 oracle. in=3 out=3 k=2.
8702    #[test]
8703    fn test_conv1d_groups3_depthwise_forward_and_backward_matches_torch() {
8704        // weight [3, 1, 2] = arange(1..=6) * 0.5.
8705        let weight: Vec<f32> = (1..=6).map(|i| i as f32 * 0.5).collect();
8706        let conv = conv1d_full_fixed(3, 3, 2, 1, 3, &weight, None);
8707
8708        // x [1, 3, 6] = arange(1..=18).
8709        let x_data: Vec<f32> = (1..=18).map(|i| i as f32).collect();
8710        let x = leaf(&x_data, &[1, 3, 6]);
8711        let y = conv.forward(&x).unwrap();
8712        assert_eq!(y.shape(), &[1, 3, 5]);
8713        // torch B_fwd:
8714        assert_close(
8715            y.data().unwrap(),
8716            &[
8717                2.5, 4.0, 5.5, 7.0, 8.5, 26.5, 30.0, 33.5, 37.0, 40.5, 74.5, 80.0, 85.5, 91.0, 96.5,
8718            ],
8719            1e-3,
8720        );
8721
8722        let grad_output = t(&[1.0f32; 15], &[1, 3, 5]);
8723        let grads = conv
8724            .forward(&x)
8725            .unwrap()
8726            .grad_fn()
8727            .unwrap()
8728            .backward(&grad_output)
8729            .unwrap();
8730        // grad_input (torch B_gx):
8731        assert_close(
8732            grads[0].as_ref().unwrap().data().unwrap(),
8733            &[
8734                0.5, 1.5, 1.5, 1.5, 1.5, 1.0, 1.5, 3.5, 3.5, 3.5, 3.5, 2.0, 2.5, 5.5, 5.5, 5.5,
8735                5.5, 3.0,
8736            ],
8737            1e-4,
8738        );
8739        // grad_weight (torch B_gw) — shape [3, 1, 2]:
8740        assert_eq!(grads[1].as_ref().unwrap().shape(), &[3, 1, 2]);
8741        assert_close(
8742            grads[1].as_ref().unwrap().data().unwrap(),
8743            &[15.0, 20.0, 45.0, 50.0, 75.0, 80.0],
8744            1e-4,
8745        );
8746    }
8747
8748    /// Dilated Conv1d, dilation=2, groups=1. Forward + grad_x + grad_w +
8749    /// grad_b match the live torch 2.11 oracle. in=2 out=2 k=3 dilation=2 =>
8750    /// eff_k=5, L=7 -> L_out=3.
8751    #[test]
8752    fn test_conv1d_dilation2_forward_and_backward_matches_torch() {
8753        // weight [2, 2, 3] = arange(1..=12) * 0.1; bias [1.0, -1.0].
8754        let weight: Vec<f32> = (1..=12).map(|i| i as f32 * 0.1).collect();
8755        let bias = [1.0f32, -1.0];
8756        let conv = conv1d_full_fixed(2, 2, 3, 2, 1, &weight, Some(&bias));
8757
8758        // x [1, 2, 7] = arange(1..=14).
8759        let x_data: Vec<f32> = (1..=14).map(|i| i as f32).collect();
8760        let x = leaf(&x_data, &[1, 2, 7]);
8761        let y = conv.forward(&x).unwrap();
8762        assert_eq!(y.shape(), &[1, 2, 3]);
8763        // torch C_fwd:
8764        assert_close(
8765            y.data().unwrap(),
8766            &[18.6, 20.7, 22.8, 40.0, 45.7, 51.4],
8767            1e-3,
8768        );
8769
8770        let grad_output = t(&[1.0f32; 6], &[1, 2, 3]);
8771        let grads = conv
8772            .forward(&x)
8773            .unwrap()
8774            .grad_fn()
8775            .unwrap()
8776            .backward(&grad_output)
8777            .unwrap();
8778        // grad_input (torch C_gx):
8779        assert_close(
8780            grads[0].as_ref().unwrap().data().unwrap(),
8781            &[
8782                0.8, 0.8, 1.8, 1.0, 2.2, 1.2, 1.2, 1.4, 1.4, 3.0, 1.6, 3.4, 1.8, 1.8,
8783            ],
8784            1e-4,
8785        );
8786        // grad_weight (torch C_gw) — shape [2, 2, 3]:
8787        assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 2, 3]);
8788        assert_close(
8789            grads[1].as_ref().unwrap().data().unwrap(),
8790            &[
8791                6.0, 12.0, 18.0, 27.0, 33.0, 39.0, 6.0, 12.0, 18.0, 27.0, 33.0, 39.0,
8792            ],
8793            1e-4,
8794        );
8795        // grad_bias (torch C_gb):
8796        assert_close(
8797            grads[2].as_ref().unwrap().data().unwrap(),
8798            &[3.0, 3.0],
8799            1e-4,
8800        );
8801    }
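
    // Illustrative sketch, not part of the crate: `sketch_conv1d_out_len` is a
    // hypothetical helper that spells out the `eff_k` / `L_out` arithmetic
    // quoted in the Conv1d doc comments above. It assumes the usual rule
    // `eff_k = dilation*(k-1) + 1` and
    // `L_out = (L + 2*padding - eff_k) / stride + 1`, with `stride >= 1`.
    fn sketch_conv1d_out_len(
        l: usize,
        k: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> Option<usize> {
        let eff_k = dilation * k.checked_sub(1)? + 1;
        // `None` when the dilated kernel does not fit in the padded input.
        Some((l + 2 * padding).checked_sub(eff_k)? / stride + 1)
    }

    #[test]
    fn sketch_conv1d_out_len_matches_fixture_shapes() {
        // Dilated fixture: L=7, k=3, dilation=2 -> eff_k=5, L_out=3.
        assert_eq!(sketch_conv1d_out_len(7, 3, 1, 0, 2), Some(3));
        // Grouped fixture: L=5, k=2 -> 4; depthwise fixture: L=6, k=2 -> 5.
        assert_eq!(sketch_conv1d_out_len(5, 2, 1, 0, 1), Some(4));
        assert_eq!(sketch_conv1d_out_len(6, 2, 1, 0, 1), Some(5));
        // A dilated kernel wider than the input has no valid output position.
        assert_eq!(sketch_conv1d_out_len(3, 3, 1, 0, 2), None);
    }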
8802
8803    /// `Conv1d::new_full` rejects `groups` that does not divide channels,
8804    /// matching `torch.nn.Conv1d`'s `ValueError` (`conv.py:107-110`).
8805    #[test]
8806    fn test_conv1d_groups_must_divide_channels() {
8807        // in_channels=3 not divisible by groups=2.
8808        assert!(Conv1d::<f32>::new_full(3, 4, 2, 1, 0, 1, 2, true).is_err());
8809        // out_channels=5 not divisible by groups=2 (in divisible).
8810        assert!(Conv1d::<f32>::new_full(4, 5, 2, 1, 0, 1, 2, true).is_err());
8811        // zero groups rejected.
8812        assert!(Conv1d::<f32>::new_full(4, 4, 2, 1, 0, 1, 0, true).is_err());
8813        // zero dilation rejected.
8814        assert!(Conv1d::<f32>::new_full(4, 4, 2, 1, 0, 0, 2, true).is_err());
8815        // valid grouped config accepted.
8816        assert!(Conv1d::<f32>::new_full(4, 4, 2, 1, 0, 1, 2, true).is_ok());
8817    }
8818
8819    /// Build a Conv3d with explicit weight/bias for deterministic oracle parity.
8820    fn conv3d_fixed(
8821        weight: &[f32],
8822        wshape: &[usize],
8823        bias: &[f32],
8824        kernel: (usize, usize, usize),
8825        padding: (usize, usize, usize),
8826        mode: crate::padding::PaddingMode,
8827    ) -> Conv3d<f32> {
8828        let w = Parameter::from_slice(weight, wshape).unwrap();
8829        let b = Parameter::from_slice(bias, &[wshape[0]]).unwrap();
8830        Conv3d {
8831            weight: w,
8832            bias: Some(b),
8833            in_channels: wshape[1],
8834            out_channels: wshape[0],
8835            kernel_size: kernel,
8836            stride: (1, 1, 1),
8837            padding,
8838            dilation: (1, 1, 1),
8839            groups: 1,
8840            padding_mode: mode,
8841            string_padding: None,
8842            training: false,
8843        }
8844    }
8845
8846    /// Conv3d replicate forward matches torch oracle.
8847    /// torch: Conv3d(1,1,(2,2,2),padding=(1,1,1),padding_mode='replicate'),
8848    /// w=arange(1..=8), b=0, x=arange(1..=8) -> 27-element [1,1,3,3,3] output.
8849    #[test]
8850    fn test_conv3d_replicate_forward_matches_torch() {
8851        let w: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8852        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8853        let conv = conv3d_fixed(
8854            &w,
8855            &[1, 1, 2, 2, 2],
8856            &[0.0],
8857            (2, 2, 2),
8858            (1, 1, 1),
8859            crate::padding::PaddingMode::Replicate,
8860        );
8861        let x = t(&x_data, &[1, 1, 2, 2, 2]);
8862        let y = conv.forward(&x).unwrap();
8863        assert_eq!(y.shape(), &[1, 1, 3, 3, 3]);
8864        let expected = [
8865            36.0, 56.0, 72.0, 80.0, 100.0, 116.0, 108.0, 128.0, 144.0, 140.0, 160.0, 176.0, 184.0,
8866            204.0, 220.0, 212.0, 232.0, 248.0, 180.0, 200.0, 216.0, 224.0, 244.0, 260.0, 252.0,
8867            272.0, 288.0,
8868        ];
8869        assert_close(y.data().unwrap(), &expected, 1e-3);
8870    }
8871
8872    /// Conv3d reflect forward matches torch oracle (same fixture, reflect mode).
8873    #[test]
8874    fn test_conv3d_reflect_forward_matches_torch() {
8875        let w: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8876        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8877        let conv = conv3d_fixed(
8878            &w,
8879            &[1, 1, 2, 2, 2],
8880            &[0.0],
8881            (2, 2, 2),
8882            (1, 1, 1),
8883            crate::padding::PaddingMode::Reflect,
8884        );
8885        let x = t(&x_data, &[1, 1, 2, 2, 2]);
8886        let y = conv.forward(&x).unwrap();
8887        let expected = [
8888            120.0, 124.0, 120.0, 136.0, 140.0, 136.0, 120.0, 124.0, 120.0, 184.0, 188.0, 184.0,
8889            200.0, 204.0, 200.0, 184.0, 188.0, 184.0, 120.0, 124.0, 120.0, 136.0, 140.0, 136.0,
8890            120.0, 124.0, 120.0,
8891        ];
8892        assert_close(y.data().unwrap(), &expected, 1e-3);
8893    }
8894
8895    /// Conv3d circular forward matches torch oracle (discriminating asymmetric
8896    /// fixture: single-tap kernel + non-symmetric input so circular != reflect).
    /// torch: w[0,0,0,0,0]=1 else 0, x=[[[1,2],[3,4]],[[5,6],[7,8]]].
8898    #[test]
8899    fn test_conv3d_circular_forward_matches_torch() {
8900        let mut w = vec![0.0f32; 8];
8901        w[0] = 1.0;
8902        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8903        let conv = conv3d_fixed(
8904            &w,
8905            &[1, 1, 2, 2, 2],
8906            &[0.0],
8907            (2, 2, 2),
8908            (1, 1, 1),
8909            crate::padding::PaddingMode::Circular,
8910        );
8911        let x = t(&x_data, &[1, 1, 2, 2, 2]);
8912        let y = conv.forward(&x).unwrap();
8913        let expected = [
8914            8.0, 7.0, 8.0, 6.0, 5.0, 6.0, 8.0, 7.0, 8.0, 4.0, 3.0, 4.0, 2.0, 1.0, 2.0, 4.0, 3.0,
8915            4.0, 8.0, 7.0, 8.0, 6.0, 5.0, 6.0, 8.0, 7.0, 8.0,
8916        ];
8917        assert_close(y.data().unwrap(), &expected, 1e-3);
8918    }
8919
8920    /// Conv3d replicate backward: input gradient flows through the non-zero pad
8921    /// (the #1550 regression class). torch grad_input for out.sum().backward()
8922    /// = [90, 99, 108, 117, 126, 135, 144, 153].
8923    #[test]
8924    fn test_conv3d_replicate_backward_input_grad_matches_torch() {
8925        let w: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8926        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
8927        let conv = conv3d_fixed(
8928            &w,
8929            &[1, 1, 2, 2, 2],
8930            &[0.0],
8931            (2, 2, 2),
8932            (1, 1, 1),
8933            crate::padding::PaddingMode::Replicate,
8934        );
8935        let x = leaf(&x_data, &[1, 1, 2, 2, 2]);
8936        let y = conv.forward(&x).unwrap();
8937        assert!(
8938            y.grad_fn().is_some(),
8939            "Conv3d replicate output lost its grad_fn — pre-pad severed autograd (#1550 class)"
8940        );
8941        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
8942        ferrotorch_core::backward(&sum).unwrap();
8943        let xg = x.grad().unwrap().expect("input grad must be populated");
8944        assert_close(
8945            xg.data().unwrap(),
8946            &[90.0, 99.0, 108.0, 117.0, 126.0, 135.0, 144.0, 153.0],
8947            1e-3,
8948        );
8949    }
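
    // The three non-zero padding modes exercised above differ only in how a
    // padded coordinate maps back to a source index. A minimal sketch of that
    // mapping along one axis follows. `PadModeSketch` and
    // `pad_src_index_sketch` are hypothetical, test-local names (not the
    // `crate::padding` API); the sketch assumes `len >= 1`, and the reflect
    // arm additionally assumes `pad < len`.
    #[derive(Clone, Copy)]
    enum PadModeSketch {
        Replicate,
        Reflect,
        Circular,
    }

    fn pad_src_index_sketch(mode: PadModeSketch, p: usize, pad: usize, len: usize) -> usize {
        // Signed offset of padded coordinate `p` relative to the unpadded axis.
        let i = p as isize - pad as isize;
        let n = len as isize;
        let src = match mode {
            // Clamp to the nearest edge element.
            PadModeSketch::Replicate => i.clamp(0, n - 1),
            // Mirror about the edge element without repeating it.
            PadModeSketch::Reflect => {
                if i < 0 {
                    -i
                } else if i >= n {
                    2 * (n - 1) - i
                } else {
                    i
                }
            }
            // Wrap around modulo the axis length.
            PadModeSketch::Circular => i.rem_euclid(n),
        };
        src as usize
    }

    #[test]
    fn test_pad_src_index_sketch() {
        let idx = |m: PadModeSketch, len: usize| -> Vec<usize> {
            (0..len + 2)
                .map(|p| pad_src_index_sketch(m, p, 1, len))
                .collect()
        };
        // Length-2 axis, pad 1 (the Conv3d fixtures above).
        assert_eq!(idx(PadModeSketch::Replicate, 2), vec![0, 0, 1, 1]);
        assert_eq!(idx(PadModeSketch::Reflect, 2), vec![1, 0, 1, 0]);
        assert_eq!(idx(PadModeSketch::Circular, 2), vec![1, 0, 1, 0]);
        // A length-3 axis separates reflect from circular.
        assert_eq!(idx(PadModeSketch::Reflect, 3), vec![1, 0, 1, 2, 1]);
        assert_eq!(idx(PadModeSketch::Circular, 3), vec![2, 0, 1, 2, 0]);
    }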
8950
8951    // -----------------------------------------------------------------------
8952    // Conv3d groups + dilation (closes #1601) — oracle: live torch 2.11.0
8953    // -----------------------------------------------------------------------
8954
8955    /// Grouped + dilated Conv3d, groups=2, dilation=2. Forward + grad_x +
8956    /// grad_w + grad_b match the live torch 2.11 oracle. in=2 out=2
8957    /// k=(2,2,2) groups=2 dilation=(2,2,2) over a 4x4x4 volume => eff_k=3,
8958    /// out spatial = 2x2x2.
8959    #[test]
8960    fn test_conv3d_groups2_dilation2_forward_and_backward_matches_torch() {
8961        // weight [2, 1, 2, 2, 2] = arange(1..=16) * 0.01; bias [0.1, -0.1].
8962        let weight: Vec<f32> = (1..=16).map(|i| i as f32 * 0.01).collect();
8963        let bias = [0.1f32, -0.1];
8964        let mut conv =
8965            Conv3d::<f32>::new_full(2, 2, (2, 2, 2), (1, 1, 1), (0, 0, 0), (2, 2, 2), 2, true)
8966                .unwrap();
8967        conv.weight = Parameter::from_slice(&weight, &[2, 1, 2, 2, 2]).unwrap();
8968        conv.bias = Some(Parameter::from_slice(&bias, &[2]).unwrap());
8969
8970        // x [1, 2, 4, 4, 4] = arange(1..=128).
8971        let x_data: Vec<f32> = (1..=128).map(|i| i as f32).collect();
8972        let x = leaf(&x_data, &[1, 2, 4, 4, 4]);
8973        let y = conv.forward(&x).unwrap();
8974        assert_eq!(y.shape(), &[1, 2, 2, 2, 2]);
8975        // torch D_fwd:
8976        assert_close(
8977            y.data().unwrap(),
8978            &[
8979                10.94, 11.3, 12.38, 12.74, 16.7, 17.06, 18.14, 18.5, 88.82, 89.82, 92.82, 93.82,
8980                104.82, 105.82, 108.82, 109.82,
8981            ],
8982            1e-3,
8983        );
8984
8985        let grad_output = t(&[1.0f32; 16], &[1, 2, 2, 2, 2]);
8986        let grads = conv
8987            .forward(&x)
8988            .unwrap()
8989            .grad_fn()
8990            .unwrap()
8991            .backward(&grad_output)
8992            .unwrap();
8993        // grad_input (torch D_gx) — full 128 elements:
8994        #[rustfmt::skip]
8995        let d_gx: [f32; 128] = [
8996            0.01, 0.01, 0.02, 0.02, 0.01, 0.01, 0.02, 0.02, 0.03, 0.03, 0.04, 0.04, 0.03, 0.03,
8997            0.04, 0.04, 0.01, 0.01, 0.02, 0.02, 0.01, 0.01, 0.02, 0.02, 0.03, 0.03, 0.04, 0.04,
8998            0.03, 0.03, 0.04, 0.04, 0.05, 0.05, 0.06, 0.06, 0.05, 0.05, 0.06, 0.06, 0.07, 0.07,
8999            0.08, 0.08, 0.07, 0.07, 0.08, 0.08, 0.05, 0.05, 0.06, 0.06, 0.05, 0.05, 0.06, 0.06,
9000            0.07, 0.07, 0.08, 0.08, 0.07, 0.07, 0.08, 0.08, 0.09, 0.09, 0.1, 0.1, 0.09, 0.09, 0.1,
9001            0.1, 0.11, 0.11, 0.12, 0.12, 0.11, 0.11, 0.12, 0.12, 0.09, 0.09, 0.1, 0.1, 0.09, 0.09,
9002            0.1, 0.1, 0.11, 0.11, 0.12, 0.12, 0.11, 0.11, 0.12, 0.12, 0.13, 0.13, 0.14, 0.14, 0.13,
9003            0.13, 0.14, 0.14, 0.15, 0.15, 0.16, 0.16, 0.15, 0.15, 0.16, 0.16, 0.13, 0.13, 0.14,
9004            0.14, 0.13, 0.13, 0.14, 0.14, 0.15, 0.15, 0.16, 0.16, 0.15, 0.15, 0.16, 0.16,
9005        ];
9006        assert_close(grads[0].as_ref().unwrap().data().unwrap(), &d_gx, 1e-4);
9007        // grad_weight (torch D_gw) — shape [2, 1, 2, 2, 2]:
9008        assert_eq!(grads[1].as_ref().unwrap().shape(), &[2, 1, 2, 2, 2]);
9009        assert_close(
9010            grads[1].as_ref().unwrap().data().unwrap(),
9011            &[
9012                92.0, 108.0, 156.0, 172.0, 348.0, 364.0, 412.0, 428.0, 604.0, 620.0, 668.0, 684.0,
9013                860.0, 876.0, 924.0, 940.0,
9014            ],
9015            1e-3,
9016        );
9017        // grad_bias (torch D_gb):
9018        assert_close(
9019            grads[2].as_ref().unwrap().data().unwrap(),
9020            &[8.0, 8.0],
9021            1e-4,
9022        );
9023    }
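
    // The shape arithmetic quoted for the groups=2, dilation=2 fixture above
    // (eff_k = 3, output 2 per spatial dim) can be sketched as follows.
    // `conv_out_len_sketch` is a hypothetical, test-local helper that assumes
    // the standard convolution output-size formula with symmetric padding; it
    // is not a ferrotorch API.
    fn conv_out_len_sketch(
        input: usize,
        kernel: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> Option<usize> {
        // Effective kernel extent once dilation spreads the taps apart.
        let eff_kernel = dilation * kernel.checked_sub(1)? + 1;
        let padded = input + 2 * padding;
        // No valid output when the stride is zero or the kernel does not fit.
        if stride == 0 || padded < eff_kernel {
            return None;
        }
        Some((padded - eff_kernel) / stride + 1)
    }

    #[test]
    fn test_conv_out_len_sketch() {
        // 4-wide axis, k=2, dilation=2 -> eff_k=3 -> 2 outputs.
        assert_eq!(conv_out_len_sketch(4, 2, 1, 0, 2), Some(2));
        // 2-wide axis, k=2, padding=1 -> 3 outputs (the padded Conv3d fixtures).
        assert_eq!(conv_out_len_sketch(2, 2, 1, 1, 1), Some(3));
        // Kernel wider than the padded input -> no output.
        assert_eq!(conv_out_len_sketch(2, 3, 1, 0, 1), None);
    }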
9024
9025    /// Grouped Conv3d (groups=2, dilation=1) sanity: a 1x1x1 grouped conv is
9026    /// a per-group channel mix. Forward + grad_x + grad_w match the live
9027    /// torch 2.11 oracle. in=2 out=4 k=(1,1,1) groups=2.
9028    #[test]
9029    fn test_conv3d_groups2_forward_and_backward_matches_torch() {
9030        // weight [4, 1, 1, 1, 1] = [1, 2, 3, 4], no bias.
9031        let weight = [1.0f32, 2.0, 3.0, 4.0];
9032        let mut conv =
9033            Conv3d::<f32>::new_full(2, 4, (1, 1, 1), (1, 1, 1), (0, 0, 0), (1, 1, 1), 2, false)
9034                .unwrap();
9035        conv.weight = Parameter::from_slice(&weight, &[4, 1, 1, 1, 1]).unwrap();
9036
9037        // x [1, 2, 2, 2, 2] = arange(1..=16).
9038        let x_data: Vec<f32> = (1..=16).map(|i| i as f32).collect();
9039        let x = leaf(&x_data, &[1, 2, 2, 2, 2]);
9040        let y = conv.forward(&x).unwrap();
9041        assert_eq!(y.shape(), &[1, 4, 2, 2, 2]);
9042        // torch E_fwd:
9043        assert_close(
9044            y.data().unwrap(),
9045            &[
9046                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0,
9047                27.0, 30.0, 33.0, 36.0, 39.0, 42.0, 45.0, 48.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0,
9048                60.0, 64.0,
9049            ],
9050            1e-3,
9051        );
9052
9053        let grad_output = t(&[1.0f32; 32], &[1, 4, 2, 2, 2]);
9054        let grads = conv
9055            .forward(&x)
9056            .unwrap()
9057            .grad_fn()
9058            .unwrap()
9059            .backward(&grad_output)
9060            .unwrap();
9061        // grad_input (torch E_gx):
9062        assert_close(
9063            grads[0].as_ref().unwrap().data().unwrap(),
9064            &[
9065                3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0,
9066            ],
9067            1e-4,
9068        );
9069        // grad_weight (torch E_gw) — shape [4, 1, 1, 1, 1]:
9070        assert_eq!(grads[1].as_ref().unwrap().shape(), &[4, 1, 1, 1, 1]);
9071        assert_close(
9072            grads[1].as_ref().unwrap().data().unwrap(),
9073            &[36.0, 36.0, 100.0, 100.0],
9074            1e-4,
9075        );
9076    }
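
    // The "per-group channel mix" reading of a 1x1x1 grouped conv used by the
    // test above can be written out directly on flat, channel-major data. This
    // is a minimal reference sketch, not the im2col path under test;
    // `grouped_pointwise_sketch` is a hypothetical, test-local helper that
    // assumes batch size 1, no bias, and `groups` dividing both channel counts.
    fn grouped_pointwise_sketch(
        x: &[f32],
        in_ch: usize,
        weight: &[f32],
        out_ch: usize,
        groups: usize,
    ) -> Vec<f32> {
        let spatial = x.len() / in_ch;
        let in_per_group = in_ch / groups;
        let out_per_group = out_ch / groups;
        let mut y = vec![0.0f32; out_ch * spatial];
        for oc in 0..out_ch {
            // Each output channel only sees the input channels of its own group.
            let g = oc / out_per_group;
            for ic_local in 0..in_per_group {
                let ic = g * in_per_group + ic_local;
                let w = weight[oc * in_per_group + ic_local];
                for s in 0..spatial {
                    y[oc * spatial + s] += w * x[ic * spatial + s];
                }
            }
        }
        y
    }

    #[test]
    fn test_grouped_pointwise_sketch() {
        // Same fixture as above: x = 1..=16 over 2 channels, weight [1, 2, 3, 4].
        let x: Vec<f32> = (1..=16).map(|i| i as f32).collect();
        let y = grouped_pointwise_sketch(&x, 2, &[1.0, 2.0, 3.0, 4.0], 4, 2);
        assert_eq!(y.len(), 32);
        assert_eq!(&y[0..2], &[1.0, 2.0]);
        assert_eq!(&y[8..10], &[2.0, 4.0]);
        assert_eq!(&y[16..18], &[27.0, 30.0]);
        assert_eq!(&y[30..32], &[60.0, 64.0]);
    }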
9077
    /// `Conv3d::new_full` rejects `groups` that does not divide both channel
    /// counts, as well as zero `groups` and zero `dilation`. The divisibility
    /// check matches `torch.nn.Conv3d`'s `ValueError` (`conv.py:107-110`).
9080    #[test]
9081    fn test_conv3d_groups_must_divide_channels() {
9082        // in_channels=3 not divisible by groups=2.
9083        assert!(
9084            Conv3d::<f32>::new_full(3, 4, (2, 2, 2), (1, 1, 1), (0, 0, 0), (1, 1, 1), 2, true)
9085                .is_err()
9086        );
9087        // out_channels=5 not divisible by groups=2.
9088        assert!(
9089            Conv3d::<f32>::new_full(4, 5, (2, 2, 2), (1, 1, 1), (0, 0, 0), (1, 1, 1), 2, true)
9090                .is_err()
9091        );
9092        // zero groups rejected.
9093        assert!(
9094            Conv3d::<f32>::new_full(4, 4, (2, 2, 2), (1, 1, 1), (0, 0, 0), (1, 1, 1), 0, true)
9095                .is_err()
9096        );
9097        // zero dilation rejected.
9098        assert!(
9099            Conv3d::<f32>::new_full(4, 4, (2, 2, 2), (1, 1, 1), (0, 0, 0), (0, 1, 1), 2, true)
9100                .is_err()
9101        );
9102        // valid grouped+dilated config accepted.
9103        assert!(
9104            Conv3d::<f32>::new_full(4, 4, (2, 2, 2), (1, 1, 1), (0, 0, 0), (2, 2, 2), 2, true)
9105                .is_ok()
9106        );
9107    }
9108
9109    /// Conv1d with padding=0 ignores padding_mode (no pre-pad), matching torch
9110    /// (the `self.padding != 0` short-circuit in the forward).
9111    #[test]
9112    fn test_conv1d_reflect_zero_padding_is_noop() {
9113        let conv = conv1d_fixed(
9114            &[1.0, 2.0, 3.0],
9115            &[1, 1, 3],
9116            &[0.0],
9117            3,
9118            0,
9119            crate::padding::PaddingMode::Reflect,
9120        );
9121        let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
9122        let y = conv.forward(&x).unwrap();
9123        // padding=0 -> output length 3, plain conv: [1*1+2*2+3*3, ...]
9124        assert_eq!(y.shape(), &[1, 1, 3]);
9125        assert_close(y.data().unwrap(), &[14.0, 20.0, 26.0], 1e-4);
9126    }
9127
9128    // --- ConvTranspose: non-zeros padding_mode rejected (conv.py:755-758) ---
9129
9130    #[test]
9131    fn test_conv_transpose1d_reflect_padding_mode_rejected() {
9132        let conv = ConvTranspose1d::<f32>::new(2, 2, 3, 1, 0, 0, false).unwrap();
9133        let err = conv
9134            .with_padding_mode(crate::padding::PaddingMode::Reflect)
9135            .unwrap_err();
        // The message must contain torch's wording verbatim:
        // 'Only "zeros" padding mode is supported for ConvTranspose1d'.
9138        let msg = format!("{err}");
9139        assert!(
9140            msg.contains("Only \"zeros\" padding mode is supported for ConvTranspose1d"),
9141            "got: {msg}"
9142        );
9143    }
9144
9145    #[test]
9146    fn test_conv_transpose2d_replicate_padding_mode_rejected() {
9147        let conv =
9148            ConvTranspose2d::<f32>::new(2, 2, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
9149        let err = conv
9150            .with_padding_mode(crate::padding::PaddingMode::Replicate)
9151            .unwrap_err();
9152        let msg = format!("{err}");
9153        assert!(
9154            msg.contains("Only \"zeros\" padding mode is supported for ConvTranspose2d"),
9155            "got: {msg}"
9156        );
9157    }
9158
9159    #[test]
9160    fn test_conv_transpose3d_circular_padding_mode_rejected() {
9161        let conv =
9162            ConvTranspose3d::<f32>::new(2, 2, (3, 3, 3), (1, 1, 1), (0, 0, 0), (0, 0, 0), false)
9163                .unwrap();
9164        let err = conv
9165            .with_padding_mode(crate::padding::PaddingMode::Circular)
9166            .unwrap_err();
9167        let msg = format!("{err}");
9168        assert!(
9169            msg.contains("Only \"zeros\" padding mode is supported for ConvTranspose3d"),
9170            "got: {msg}"
9171        );
9172    }
9173
9174    /// ConvTranspose accepts the `Zeros` mode (the only valid one) unchanged.
9175    #[test]
9176    fn test_conv_transpose2d_zeros_padding_mode_accepted() {
9177        let conv =
9178            ConvTranspose2d::<f32>::new(2, 2, (3, 3), (1, 1), (0, 0), (0, 0), false).unwrap();
9179        assert!(
9180            conv.with_padding_mode(crate::padding::PaddingMode::Zeros)
9181                .is_ok()
9182        );
9183    }
9184
9185    // =======================================================================
9186    // String padding 'same' / 'valid'  — #1602
9187    // Oracle values are from live torch 2.11.0 (`F.conv{1,2,3}d(..,
9188    // padding="same"|"valid")` / `nn.Conv2d(.., padding="same")`), R-CHAR-3.
9189    // The asymmetric 'same' split is `left = total/2`, `right = total - left`
9190    // with `total = dilation*(kernel-1)` (`aten/src/ATen/native/Pool.h:91-107`,
9191    // `torch/nn/modules/conv.py:143-155`) — the END side gets the extra unit.
9192    // =======================================================================
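
    // A minimal sketch of the 'same' split described in the banner above.
    // `same_pad_split_sketch` is a hypothetical, test-local helper (not a
    // ferrotorch or torch API); it assumes `kernel >= 1`, stride 1, and
    // integer division that rounds down.
    fn same_pad_split_sketch(kernel: usize, dilation: usize) -> (usize, usize) {
        // Total padding that keeps the stride-1 output as long as the input.
        let total = dilation * (kernel - 1);
        // The start side gets the floor; the end side gets the remainder.
        let left = total / 2;
        (left, total - left)
    }

    #[test]
    fn test_same_pad_split_sketch() {
        // Odd kernel k=3: symmetric (1, 1).
        assert_eq!(same_pad_split_sketch(3, 1), (1, 1));
        // Even kernel k=4: total=3 -> (1, 2), extra unit at the end.
        assert_eq!(same_pad_split_sketch(4, 1), (1, 2));
        // Even kernel k=2: total=1 -> (0, 1).
        assert_eq!(same_pad_split_sketch(2, 1), (0, 1));
    }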
9193
9194    /// Build a Conv1d with explicit weight/bias for deterministic oracle parity.
9195    fn conv1d_with_weight(weight: &[f32], wshape: &[usize], bias: f32) -> Conv1d<f32> {
9196        let mut c = Conv1d::<f32>::new(wshape[1], wshape[0], wshape[2], 1, 0, true).unwrap();
9197        // Direct field write (tests live in the same module), mirroring the
9198        // existing `conv1d_full_fixed` helper.
9199        c.weight = Parameter::from_slice(weight, wshape).unwrap();
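        // Repeat the scalar bias per output channel so the data length always
        // matches the `[out_channels]` shape, not only when `wshape[0] == 1`.
        c.bias = Some(Parameter::from_slice(&vec![bias; wshape[0]], &[wshape[0]]).unwrap());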
9201        c
9202    }
9203
9204    /// Conv1d 'same', ODD kernel k=3 (symmetric pad 1,1).
9205    /// torch: F.conv1d([[[1,2,3,4,5]]], w=[1,2,3], b=0.5, padding="same")
9206    ///   -> [8.5, 14.5, 20.5, 26.5, 14.5]
9207    #[test]
9208    fn test_conv1d_same_odd_kernel_matches_torch() {
9209        let conv = conv1d_with_weight(&[1.0, 2.0, 3.0], &[1, 1, 3], 0.5)
9210            .with_string_padding(StringPadding::Same)
9211            .unwrap();
9212        let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
9213        let y = conv.forward(&x).unwrap();
9214        assert_eq!(y.shape(), &[1, 1, 5]);
9215        assert_close(y.data().unwrap(), &[8.5, 14.5, 20.5, 26.5, 14.5], 1e-4);
9216    }
9217
9218    /// Conv1d 'same', EVEN kernel k=4 — ASYMMETRIC pad (total=3 -> left=1,
9219    /// right=2; the END gets the extra unit). torch:
9220    ///   F.conv1d([[[1..6]]], w=[1,2,3,4], b=0, padding="same")
9221    ///   -> [20, 30, 40, 50, 32, 17]
9222    /// A symmetric (left=right) split would give a different sequence, so this
9223    /// test FAILS unless the asymmetric `right = total - total/2` is correct.
9224    #[test]
9225    fn test_conv1d_same_even_kernel_asymmetric_matches_torch() {
9226        let conv = conv1d_with_weight(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4], 0.0)
9227            .with_string_padding(StringPadding::Same)
9228            .unwrap();
9229        let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 1, 6]);
9230        let y = conv.forward(&x).unwrap();
9231        assert_eq!(y.shape(), &[1, 1, 6]);
9232        assert_close(
9233            y.data().unwrap(),
9234            &[20.0, 30.0, 40.0, 50.0, 32.0, 17.0],
9235            1e-4,
9236        );
9237    }
9238
9239    /// Conv1d 'same' backward: gradient flows through the autograd-aware
9240    /// asymmetric pre-pad back to the original input shape.
9241    /// torch grad of sum(F.conv1d([[[1,2,3,4,5]]], w=[1,2,3], b=0.5,
9242    ///   padding="same")) wrt x = [3, 6, 6, 6, 5].
9243    #[test]
9244    fn test_conv1d_same_odd_kernel_backward_matches_torch() {
9245        let conv = conv1d_with_weight(&[1.0, 2.0, 3.0], &[1, 1, 3], 0.5)
9246            .with_string_padding(StringPadding::Same)
9247            .unwrap();
9248        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
9249        let y = conv.forward(&x).unwrap();
9250        // out.sum().backward() — grad_output is all-ones (matches the torch oracle).
9251        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9252        ferrotorch_core::backward(&sum).unwrap();
9253        let gx = x.grad().unwrap().expect("input grad must be populated");
9254        assert_eq!(gx.shape(), &[1, 1, 5]);
9255        assert_close(gx.data().unwrap(), &[3.0, 6.0, 6.0, 6.0, 5.0], 1e-4);
9256    }
9257
9258    /// Conv1d 'same' backward, EVEN kernel asymmetric. torch grad wrt x =
9259    ///   [3, 6, 10, 10, 10, 9].
9260    #[test]
9261    fn test_conv1d_same_even_kernel_backward_matches_torch() {
9262        let conv = conv1d_with_weight(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4], 0.0)
9263            .with_string_padding(StringPadding::Same)
9264            .unwrap();
9265        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 1, 6]);
9266        let y = conv.forward(&x).unwrap();
9267        // out.sum().backward() — grad_output is all-ones (matches the torch oracle).
9268        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9269        ferrotorch_core::backward(&sum).unwrap();
9270        let gx = x.grad().unwrap().expect("input grad must be populated");
9271        assert_eq!(gx.shape(), &[1, 1, 6]);
9272        assert_close(gx.data().unwrap(), &[3.0, 6.0, 10.0, 10.0, 10.0, 9.0], 1e-4);
9273    }
9274
9275    /// Conv1d 'valid' == padding 0. torch:
9276    ///   F.conv1d([[[1,2,3,4,5]]], w=[1,1,1], b=0, padding="valid") -> [6,9,12]
9277    #[test]
9278    fn test_conv1d_valid_matches_torch() {
9279        let conv = conv1d_with_weight(&[1.0, 1.0, 1.0], &[1, 1, 3], 0.0)
9280            .with_string_padding(StringPadding::Valid)
9281            .unwrap();
9282        let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
9283        let y = conv.forward(&x).unwrap();
9284        assert_eq!(y.shape(), &[1, 1, 3]);
9285        assert_close(y.data().unwrap(), &[6.0, 9.0, 12.0], 1e-4);
9286    }
9287
9288    /// Conv2d 'same', ODD kernel 3x3 (symmetric pad). torch oracle from
9289    ///   F.conv2d(arange(1..17).view(1,1,4,4), arange(1..10).view(1,1,3,3),
9290    ///            b=0.5, padding="same").
9291    #[test]
9292    fn test_conv2d_same_odd_kernel_matches_torch() {
9293        let weight = Parameter::from_slice(
9294            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
9295            &[1, 1, 3, 3],
9296        )
9297        .unwrap();
9298        let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), true).unwrap();
9299        conv.set_weight(weight).unwrap();
9300        conv.bias = Some(Parameter::from_slice(&[0.5], &[1]).unwrap());
9301        let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9302        let x = t(
9303            &(1..=16).map(|v| v as f32).collect::<Vec<_>>(),
9304            &[1, 1, 4, 4],
9305        );
9306        let y = conv.forward(&x).unwrap();
9307        assert_eq!(y.shape(), &[1, 1, 4, 4]);
9308        let expected = [
9309            111.5, 178.5, 217.5, 145.5, 231.5, 348.5, 393.5, 252.5, 363.5, 528.5, 573.5, 360.5,
9310            197.5, 274.5, 295.5, 175.5,
9311        ];
9312        assert_close(y.data().unwrap(), &expected, 1e-3);
9313    }
9314
9315    /// Conv2d 'same', EVEN kernel (2,2) — ASYMMETRIC per dim (total=1 ->
9316    /// left/top=0, right/bottom=1). torch oracle:
9317    ///   F.conv2d(arange(1..10).view(1,1,3,3), [[1,2],[3,4]], b=0, "same").
9318    #[test]
9319    fn test_conv2d_same_even_kernel_asymmetric_matches_torch() {
9320        let weight = Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]).unwrap();
9321        let mut conv = Conv2d::<f32>::new(1, 1, (2, 2), (1, 1), (0, 0), false).unwrap();
9322        conv.set_weight(weight).unwrap();
9323        let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9324        let x = t(
9325            &(1..=9).map(|v| v as f32).collect::<Vec<_>>(),
9326            &[1, 1, 3, 3],
9327        );
9328        let y = conv.forward(&x).unwrap();
9329        assert_eq!(y.shape(), &[1, 1, 3, 3]);
9330        let expected = [37.0, 47.0, 21.0, 67.0, 77.0, 33.0, 23.0, 26.0, 9.0];
9331        assert_close(y.data().unwrap(), &expected, 1e-3);
9332    }
9333
9334    /// Conv2d 'same' backward (odd 3x3). torch grad wrt x:
9335    ///   [[12,21,21,16],[27,45,45,33],[27,45,45,33],[24,39,39,28]].
9336    #[test]
9337    fn test_conv2d_same_backward_matches_torch() {
9338        let weight = Parameter::from_slice(
9339            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
9340            &[1, 1, 3, 3],
9341        )
9342        .unwrap();
9343        let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), true).unwrap();
9344        conv.set_weight(weight).unwrap();
9345        conv.bias = Some(Parameter::from_slice(&[0.5], &[1]).unwrap());
9346        let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9347        let x = leaf(
9348            &(1..=16).map(|v| v as f32).collect::<Vec<_>>(),
9349            &[1, 1, 4, 4],
9350        );
9351        let y = conv.forward(&x).unwrap();
9352        // out.sum().backward() — grad_output is all-ones (matches the torch oracle).
9353        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9354        ferrotorch_core::backward(&sum).unwrap();
9355        let gx = x.grad().unwrap().expect("input grad must be populated");
9356        assert_eq!(gx.shape(), &[1, 1, 4, 4]);
9357        let expected = [
9358            12.0, 21.0, 21.0, 16.0, 27.0, 45.0, 45.0, 33.0, 27.0, 45.0, 45.0, 33.0, 24.0, 39.0,
9359            39.0, 28.0,
9360        ];
9361        assert_close(gx.data().unwrap(), &expected, 1e-3);
9362    }
9363
9364    /// Conv2d 'valid' == padding 0. torch:
9365    ///   F.conv2d(arange(1..26).view(1,1,5,5), ones(1,1,3,3), padding="valid").
9366    #[test]
9367    fn test_conv2d_valid_matches_torch() {
9368        let weight = Parameter::from_slice(&[1.0; 9], &[1, 1, 3, 3]).unwrap();
9369        let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
9370        conv.set_weight(weight).unwrap();
9371        let conv = conv.with_string_padding(StringPadding::Valid).unwrap();
9372        let x = t(
9373            &(1..=25).map(|v| v as f32).collect::<Vec<_>>(),
9374            &[1, 1, 5, 5],
9375        );
9376        let y = conv.forward(&x).unwrap();
9377        assert_eq!(y.shape(), &[1, 1, 3, 3]);
9378        let expected = [63.0, 72.0, 81.0, 108.0, 117.0, 126.0, 153.0, 162.0, 171.0];
9379        assert_close(y.data().unwrap(), &expected, 1e-3);
9380    }
9381
9382    /// Conv3d 'same', EVEN kernel (2,2,2) — ASYMMETRIC per dim (total=1 ->
9383    /// front/top/left=0, back/bottom/right=1). torch oracle:
9384    ///   F.conv3d(arange(1..28).view(1,1,3,3,3), arange(1..9).view(1,1,2,2,2),
9385    ///            b=0, padding="same").
9386    #[test]
9387    fn test_conv3d_same_even_kernel_asymmetric_matches_torch() {
9388        let weight =
9389            Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 1, 2, 2, 2])
9390                .unwrap();
9391        let mut conv = Conv3d::<f32>::new(1, 1, (2, 2, 2), (1, 1, 1), (0, 0, 0), false).unwrap();
9392        // Direct field write (Conv3d has no `set_weight`; tests share the module).
9393        conv.weight = weight;
9394        let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9395        let x = t(
9396            &(1..=27).map(|v| v as f32).collect::<Vec<_>>(),
9397            &[1, 1, 3, 3, 3],
9398        );
9399        let y = conv.forward(&x).unwrap();
9400        assert_eq!(y.shape(), &[1, 1, 3, 3, 3]);
9401        let expected = [
9402            356.0, 392.0, 186.0, 464.0, 500.0, 234.0, 205.0, 219.0, 99.0, 680.0, 716.0, 330.0,
9403            788.0, 824.0, 378.0, 331.0, 345.0, 153.0, 217.0, 227.0, 93.0, 247.0, 257.0, 105.0,
9404            77.0, 80.0, 27.0,
9405        ];
9406        assert_close(y.data().unwrap(), &expected, 1e-3);
9407    }
9408
9409    /// Conv3d 'same' backward (even 2x2x2). torch grad wrt x (27 values).
9410    #[test]
9411    fn test_conv3d_same_backward_matches_torch() {
9412        let weight =
9413            Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 1, 2, 2, 2])
9414                .unwrap();
9415        let mut conv = Conv3d::<f32>::new(1, 1, (2, 2, 2), (1, 1, 1), (0, 0, 0), false).unwrap();
9416        // Direct field write (Conv3d has no `set_weight`; tests share the module).
9417        conv.weight = weight;
9418        let conv = conv.with_string_padding(StringPadding::Same).unwrap();
9419        let x = leaf(
9420            &(1..=27).map(|v| v as f32).collect::<Vec<_>>(),
9421            &[1, 1, 3, 3, 3],
9422        );
9423        let y = conv.forward(&x).unwrap();
9424        // out.sum().backward() — grad_output is all-ones (matches the torch oracle).
9425        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9426        ferrotorch_core::backward(&sum).unwrap();
9427        let gx = x.grad().unwrap().expect("input grad must be populated");
9428        assert_eq!(gx.shape(), &[1, 1, 3, 3, 3]);
9429        let expected = [
9430            1.0, 3.0, 3.0, 4.0, 10.0, 10.0, 4.0, 10.0, 10.0, 6.0, 14.0, 14.0, 16.0, 36.0, 36.0,
9431            16.0, 36.0, 36.0, 6.0, 14.0, 14.0, 16.0, 36.0, 36.0, 16.0, 36.0, 36.0,
9432        ];
9433        assert_close(gx.data().unwrap(), &expected, 1e-3);
9434    }
9435
9436    /// `padding='same'` with stride>1 is rejected at construction, matching
9437    /// upstream `ValueError("padding='same' is not supported for strided
9438    /// convolutions")` (`conv.py:117-120`).
9439    #[test]
9440    fn test_conv_same_stride_gt1_rejected() {
9441        // Conv1d
9442        let c1 = Conv1d::<f32>::new(1, 1, 3, 2, 0, false)
9443            .unwrap()
9444            .with_string_padding(StringPadding::Same);
9445        let e1 = c1.unwrap_err();
9446        assert!(
9447            format!("{e1}").contains("padding='same' is not supported for strided convolutions"),
9448            "conv1d: {e1}"
9449        );
9450        // Conv2d (stride 2 in one dim)
9451        let c2 = Conv2d::<f32>::new(1, 1, (3, 3), (1, 2), (0, 0), false)
9452            .unwrap()
9453            .with_string_padding(StringPadding::Same);
9454        assert!(
9455            format!("{}", c2.unwrap_err())
9456                .contains("padding='same' is not supported for strided convolutions")
9457        );
9458        // Conv3d
9459        let c3 = Conv3d::<f32>::new(1, 1, (2, 2, 2), (2, 1, 1), (0, 0, 0), false)
9460            .unwrap()
9461            .with_string_padding(StringPadding::Same);
9462        assert!(
9463            format!("{}", c3.unwrap_err())
9464                .contains("padding='same' is not supported for strided convolutions")
9465        );
9466        // 'valid' with stride>1 is fine (no constraint).
9467        assert!(
9468            Conv2d::<f32>::new(1, 1, (3, 3), (2, 2), (0, 0), false)
9469                .unwrap()
9470                .with_string_padding(StringPadding::Valid)
9471                .is_ok()
9472        );
9473    }
9474
9475    // =======================================================================
9476    // Unbatched input (rank D+1)  — #1604
9477    // Oracle values from live torch 2.11.0; the output is rank D+1 and the
9478    // gradient flows back to the unbatched input shape (`batchify` /
9479    // `output.squeeze(0)` at `aten/src/ATen/native/Convolution.cpp:816-831,
9480    // 990-1047`), R-CHAR-3.
9481    // =======================================================================
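
    // A minimal sketch of the rank handling this section tests: a rank-`D+1`
    // input gains a leading batch dim of 1 before the conv, and the caller
    // squeezes it off the output afterwards. `batchify_shape_sketch` is a
    // hypothetical, test-local helper working on shapes only; it is not the
    // ferrotorch or ATen implementation.
    fn batchify_shape_sketch(shape: &[usize], spatial_dims: usize) -> Option<(Vec<usize>, bool)> {
        if shape.len() == spatial_dims + 2 {
            // Already batched: pass through, nothing to squeeze later.
            Some((shape.to_vec(), false))
        } else if shape.len() == spatial_dims + 1 {
            // Unbatched: prepend N=1 and remember to squeeze dim 0 afterwards.
            let mut batched = Vec::with_capacity(shape.len() + 1);
            batched.push(1);
            batched.extend_from_slice(shape);
            Some((batched, true))
        } else {
            // Any other rank is not a valid conv input.
            None
        }
    }

    #[test]
    fn test_batchify_shape_sketch() {
        // Conv1d: (C, L) -> (1, C, L), flagged for squeeze.
        assert_eq!(
            batchify_shape_sketch(&[1, 5], 1),
            Some((vec![1, 1, 5], true))
        );
        // Conv3d: already (N, C, D, H, W), passed through.
        assert_eq!(
            batchify_shape_sketch(&[1, 1, 3, 3, 3], 3),
            Some((vec![1, 1, 3, 3, 3], false))
        );
        // Rank too low for Conv2d.
        assert_eq!(batchify_shape_sketch(&[5, 5], 2), None);
    }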
9482
9483    /// Conv1d unbatched `(C, L)` -> output `(C_out, L_out)` (rank 2).
9484    /// torch: F.conv1d([[1,2,3,4,5]], w=[1,2,3], b=0.5) -> [[14.5,20.5,26.5]].
9485    #[test]
9486    fn test_conv1d_unbatched_forward_matches_torch() {
9487        let conv = conv1d_with_weight(&[1.0, 2.0, 3.0], &[1, 1, 3], 0.5);
9488        let x = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 5]); // (C=1, L=5) unbatched
9489        let y = conv.forward(&x).unwrap();
9490        assert_eq!(y.ndim(), 2, "unbatched output must be rank 2");
9491        assert_eq!(y.shape(), &[1, 3]);
9492        assert_close(y.data().unwrap(), &[14.5, 20.5, 26.5], 1e-4);
9493    }
9494
9495    /// Conv1d unbatched backward: grad shape == unbatched input `(C, L)`.
9496    /// torch grad of sum wrt x = [1, 3, 6, 5, 3].
9497    #[test]
9498    fn test_conv1d_unbatched_backward_matches_torch() {
9499        let conv = conv1d_with_weight(&[1.0, 2.0, 3.0], &[1, 1, 3], 0.5);
9500        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 5]);
9501        let y = conv.forward(&x).unwrap();
9502        // out.sum().backward() — grad_output is all-ones (matches the torch oracle).
9503        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9504        ferrotorch_core::backward(&sum).unwrap();
9505        let gx = x.grad().unwrap().expect("input grad must be populated");
9506        assert_eq!(gx.shape(), &[1, 5], "grad must match unbatched input shape");
9507        assert_close(gx.data().unwrap(), &[1.0, 3.0, 6.0, 5.0, 3.0], 1e-4);
9508    }
9509
9510    /// Conv2d unbatched `(C, H, W)` -> output `(C_out, H_out, W_out)` (rank 3).
9511    /// torch: F.conv2d(arange(1..26).view(1,5,5), arange(1..10).view(1,1,3,3),
9512    ///   b=0) -> [[[411,456,501],[636,681,726],[861,906,951]]].
9513    #[test]
9514    fn test_conv2d_unbatched_forward_matches_torch() {
9515        let weight = Parameter::from_slice(
9516            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
9517            &[1, 1, 3, 3],
9518        )
9519        .unwrap();
9520        let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
9521        conv.set_weight(weight).unwrap();
9522        let x = t(&(1..=25).map(|v| v as f32).collect::<Vec<_>>(), &[1, 5, 5]); // (C,H,W)
9523        let y = conv.forward(&x).unwrap();
9524        assert_eq!(y.ndim(), 3, "unbatched output must be rank 3");
9525        assert_eq!(y.shape(), &[1, 3, 3]);
9526        let expected = [
9527            411.0, 456.0, 501.0, 636.0, 681.0, 726.0, 861.0, 906.0, 951.0,
9528        ];
9529        assert_close(y.data().unwrap(), &expected, 1e-3);
9530    }
9531
9532    /// Conv2d unbatched backward: grad shape == unbatched input `(C, H, W)`.
9533    /// torch grad wrt x (25 values).
9534    #[test]
9535    fn test_conv2d_unbatched_backward_matches_torch() {
9536        let weight = Parameter::from_slice(
9537            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
9538            &[1, 1, 3, 3],
9539        )
9540        .unwrap();
9541        let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
9542        conv.set_weight(weight).unwrap();
9543        let x = leaf(&(1..=25).map(|v| v as f32).collect::<Vec<_>>(), &[1, 5, 5]);
9544        let y = conv.forward(&x).unwrap();
9545        // out.sum().backward() — grad_output is all-ones (matches the torch oracle).
9546        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
9547        ferrotorch_core::backward(&sum).unwrap();
9548        let gx = x.grad().unwrap().expect("input grad must be populated");
9549        assert_eq!(gx.shape(), &[1, 5, 5], "grad must match unbatched input");
9550        let expected = [
9551            1.0, 3.0, 6.0, 5.0, 3.0, 5.0, 12.0, 21.0, 16.0, 9.0, 12.0, 27.0, 45.0, 33.0, 18.0,
9552            11.0, 24.0, 39.0, 28.0, 15.0, 7.0, 15.0, 24.0, 17.0, 9.0,
9553        ];
9554        assert_close(gx.data().unwrap(), &expected, 1e-3);
9555    }

    /// Conv3d unbatched `(C, D, H, W)` -> output rank 4.
    /// torch: F.conv3d(arange(1..28).view(1,3,3,3), arange(1..9).view(1,1,2,2,2),
    ///   b=0) -> [[[[356,392],[464,500]],[[680,716],[788,824]]]].
    #[test]
    fn test_conv3d_unbatched_forward_matches_torch() {
        let weight =
            Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 1, 2, 2, 2])
                .unwrap();
        let mut conv = Conv3d::<f32>::new(1, 1, (2, 2, 2), (1, 1, 1), (0, 0, 0), false).unwrap();
        // Direct field write: `Conv3d` has no `set_weight`, and this test module can
        // access the field.
        conv.weight = weight;
        let x = t(
            &(1..=27).map(|v| v as f32).collect::<Vec<_>>(),
            &[1, 3, 3, 3],
        ); // (C,D,H,W)
        let y = conv.forward(&x).unwrap();
        assert_eq!(y.ndim(), 4, "unbatched output must be rank 4");
        assert_eq!(y.shape(), &[1, 2, 2, 2]);
        let expected = [356.0, 392.0, 464.0, 500.0, 680.0, 716.0, 788.0, 824.0];
        assert_close(y.data().unwrap(), &expected, 1e-3);
    }

    /// Conv3d unbatched backward: the input gradient has the unbatched input
    /// shape `(C, D, H, W)`. Expected values are torch's gradient w.r.t. `x` (27 values).
    #[test]
    fn test_conv3d_unbatched_backward_matches_torch() {
        let weight =
            Parameter::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 1, 2, 2, 2])
                .unwrap();
        let mut conv = Conv3d::<f32>::new(1, 1, (2, 2, 2), (1, 1, 1), (0, 0, 0), false).unwrap();
        // Direct field write: `Conv3d` has no `set_weight`, and this test module can
        // access the field.
        conv.weight = weight;
        let x = leaf(
            &(1..=27).map(|v| v as f32).collect::<Vec<_>>(),
            &[1, 3, 3, 3],
        );
        let y = conv.forward(&x).unwrap();
        // out.sum().backward(): grad_output is all-ones (matches the torch oracle).
        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
        ferrotorch_core::backward(&sum).unwrap();
        let gx = x.grad().unwrap().expect("input grad must be populated");
        assert_eq!(
            gx.shape(),
            &[1, 3, 3, 3],
            "grad must match unbatched input shape"
        );
        let expected = [
            1.0, 3.0, 2.0, 4.0, 10.0, 6.0, 3.0, 7.0, 4.0, 6.0, 14.0, 8.0, 16.0, 36.0, 20.0, 10.0,
            22.0, 12.0, 5.0, 11.0, 6.0, 12.0, 26.0, 14.0, 7.0, 15.0, 8.0,
        ];
        assert_close(gx.data().unwrap(), &expected, 1e-3);
    }

    /// Unbatched input composes with 'same' padding: Conv2d on an unbatched
    /// `(C,H,W)` input with `padding='same'` keeps the spatial dims and stays rank 3.
    #[test]
    fn test_conv2d_unbatched_same_composes() {
        let weight = Parameter::from_slice(
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
            &[1, 1, 3, 3],
        )
        .unwrap();
        let mut conv = Conv2d::<f32>::new(1, 1, (3, 3), (1, 1), (0, 0), false).unwrap();
        conv.set_weight(weight).unwrap();
        let conv = conv.with_string_padding(StringPadding::Same).unwrap();
        let x = t(&(1..=16).map(|v| v as f32).collect::<Vec<_>>(), &[1, 4, 4]); // (C,H,W)
        let y = conv.forward(&x).unwrap();
        assert_eq!(y.ndim(), 3);
        assert_eq!(y.shape(), &[1, 4, 4], "same padding preserves spatial dims");
    }
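
    // Illustrative sketch of the arithmetic behind "same padding preserves
    // spatial dims" for stride 1. It assumes the PyTorch convention
    // (total = dilation * (kernel - 1), left = total / 2, right = total - left);
    // whether `with_string_padding` computes it exactly this way internally is
    // an assumption. `same_padding_1d` and `stride1_out_len` are hypothetical
    // helper names, not ferrotorch functions. Both require `kernel >= 1`.
    fn same_padding_1d(kernel: usize, dilation: usize) -> (usize, usize) {
        let total = dilation * (kernel - 1);
        let left = total / 2;
        (left, total - left)
    }

    // Output length of a stride-1 convolution along one axis:
    // input + left + right - dilation * (kernel - 1).
    fn stride1_out_len(input: usize, kernel: usize, dilation: usize, pad: (usize, usize)) -> usize {
        input + pad.0 + pad.1 - dilation * (kernel - 1)
    }

    #[test]
    fn sketch_same_padding_arithmetic() {
        // 3x3 kernel on a 4x4 input, as in the unbatched 'same' test: pad 1 per side.
        let pad = same_padding_1d(3, 1);
        assert_eq!(pad, (1, 1));
        assert_eq!(stride1_out_len(4, 3, 1, pad), 4);
        // An even kernel needs asymmetric padding under this convention.
        assert_eq!(same_padding_1d(4, 1), (1, 2));
        assert_eq!(stride1_out_len(4, 4, 1, same_padding_1d(4, 1)), 4);
    }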
}