Skip to main content

ferrotorch_nn/
padding.rs

1//! Padding layers: constant, reflection, replication, and zero padding in 1-D, 2-D, 3-D.
2//!
3//! [CL-314] Add Conv3d, ConvTranspose1d/3d, and padding modules
4//!
5//! Each module pads the **last N** dimensions of the input tensor, matching
6//! PyTorch semantics exactly.  Padding tuples specify *(left, right)* for 1-D,
7//! *(left, right, top, bottom)* for 2-D, and
8//! *(left, right, top, bottom, front, back)* for 3-D.
9//!
10//! ## REQ status (per `.design/ferrotorch-nn/padding.md`)
11//!
12//! | REQ | Status | Evidence |
13//! |---|---|---|
14//! | REQ-1 | SHIPPED | impl: `pub enum PaddingMode` here with 4 variants `Zeros` / `Reflect` / `Replicate` / `Circular`; non-test consumer: `ferrotorch-nn/src/conv.rs` uses `PaddingMode` as the `Conv{1,2,3}d` `padding_mode` field — the non-`Zeros` forward branch routes through `functional_pad_{1,2,3}d` (wiring landed in #1443), and `ConvTranspose{1,2,3}d::with_padding_mode` matches on it to reject non-`Zeros`. |
15//! | REQ-2 | SHIPPED | impl: the grow-only `functional_pad_1d` / `functional_pad_2d` / `functional_pad_3d` entry points here dispatch on `PaddingMode`; the `Zeros`/constant arm routes through the crop-capable `functional_pad_1d_signed` / `functional_pad_2d_signed` / `functional_pad_3d_signed` (`isize` pads) which support NEGATIVE (crop) pads + mixed signs for `mode="constant"` via `pad_nd_signed_constant` + `PadNdSignedBackward`, mirroring `constant_pad_nd` at upstream `aten/src/ATen/native/PadNd.cpp:29-108` (#1611). Non-test consumer: the `usize` `functional_pad_{1,2,3}d` consume the signed entrypoints in production (the `Zeros` arm); `ferrotorch-nn/src/conv.rs` calls `functional_pad_{1,2,3}d` for the conv pre-pad; `ferrotorch-nn/src/functional.rs` re-exposes these as `nn::functional::pad`. |
16//! | REQ-3 | SHIPPED | impl: `pub struct ConstantPad{1,2,3}d<T: Float>` here, mirroring `torch/nn/modules/padding.py` constant-pad family; non-test consumer: `pub use` in `lib.rs` exposes them to external crates; the vision-model code uses `ConstantPad2d` via the `lib.rs` re-export for padding non-square inputs. |
17//! | REQ-4 | SHIPPED | impl: `pub struct ZeroPad{1,2,3}d<T: Float>` here; non-test consumer: `pub use` in `lib.rs` exposes them. |
18//! | REQ-5 | SHIPPED | impl: `pub struct ReflectionPad{1,2,3}d<T: Float>` here with reflect-overflow check inside `pad_*d_reflect`; non-test consumer: `pub use` in `lib.rs`; reflection padding is the standard for U-nets and image-translation models. |
19//! | REQ-6 | SHIPPED | impl: `pub struct ReplicationPad{1,2,3}d<T: Float>` here; non-test consumer: `pub use` in `lib.rs`. |
20//! | REQ-7 | SHIPPED | impl: `pub struct CircularPad{1,2,3}d<T: Float>` here; non-test consumer: `pub use` in `lib.rs`. |
21//! | REQ-8 | SHIPPED | impl: `macro_rules! impl_padding_module` here generates the `Module<T>` impls for all 12 structs; non-test consumer: `ferrotorch_optim` walks `Module::parameters()` of containers that include padding layers (every padding layer returns the empty parameter list, which is the correct behavior). |
22//! | REQ-9 | NOT-STARTED | blocker #1441 (umbrella) — parity-sweep runner arms absent for all 6 padding ops. The impl is end-to-end verified by 40+ lib tests; only the runner-arm wiring is missing. |
23
24use std::sync::Arc;
25
26use ferrotorch_core::autograd::no_grad::is_grad_enabled;
27use ferrotorch_core::storage::TensorStorage;
28use ferrotorch_core::tensor::{GradFn, Tensor};
29use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float};
30
31use crate::module::Module;
32use crate::parameter::Parameter;
33
34// ---------------------------------------------------------------------------
35// Padding mode enum (used by conv layers with padding_mode)
36// ---------------------------------------------------------------------------
37
/// How a convolution layer fills the border it adds around its input.
///
/// Selected through the conv layers' `padding_mode`; `Zeros` is the
/// conventional default.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PaddingMode {
    /// Fill the border with zeros (the default).
    Zeros,
    /// Mirror the input about its edge, without repeating the edge element.
    Reflect,
    /// Repeat the edge element outward (edge padding).
    Replicate,
    /// Wrap around, taking values from the opposite end of the axis.
    Circular,
}
50
51// ---------------------------------------------------------------------------
52// Low-level pad helpers (operate on raw data)
53// ---------------------------------------------------------------------------
54
55/// Pad the last dimension of a contiguous tensor.
56///
57/// `shape` has at least 1 dimension. The padding values `(left, right)` are
58/// added to dimension `ndim-1`.
59fn pad_1d_constant<T: Float>(
60    data: &[T],
61    shape: &[usize],
62    pad_left: usize,
63    pad_right: usize,
64    value: T,
65) -> (Vec<T>, Vec<usize>) {
66    let ndim = shape.len();
67    let inner = shape[ndim - 1];
68    let new_inner = inner + pad_left + pad_right;
69
70    // Number of "rows" = product of all dimensions except the last.
71    let rows: usize = shape[..ndim - 1].iter().product();
72    let rows = if rows == 0 { 1 } else { rows };
73
74    let mut out = vec![value; rows * new_inner];
75    // Degenerate input (numel 0 — e.g. an empty data buffer paired with a
76    // non-empty declared shape, or `inner == 0`): there is no source data to
77    // copy in. Mirror upstream `aten/src/ATen/native/PadNd.cpp:94-106`, which
78    // `fill_(value)`s the output then `copy_`s the source — a no-op for a
79    // zero-element input — leaving the correctly-shaped, value-filled output.
80    // The guard prevents the out-of-bounds slice on `data` (#1551).
81    if !data.is_empty() {
82        for r in 0..rows {
83            let src_start = r * inner;
84            let dst_start = r * new_inner + pad_left;
85            out[dst_start..dst_start + inner].copy_from_slice(&data[src_start..src_start + inner]);
86        }
87    }
88
89    let mut new_shape = shape.to_vec();
90    new_shape[ndim - 1] = new_inner;
91    (out, new_shape)
92}
93
94/// Pad the last 2 dimensions of a contiguous tensor with a constant value.
95fn pad_2d_constant<T: Float>(
96    data: &[T],
97    shape: &[usize],
98    pad_left: usize,
99    pad_right: usize,
100    pad_top: usize,
101    pad_bottom: usize,
102    value: T,
103) -> (Vec<T>, Vec<usize>) {
104    let ndim = shape.len();
105    let h = shape[ndim - 2];
106    let w = shape[ndim - 1];
107    let new_h = h + pad_top + pad_bottom;
108    let new_w = w + pad_left + pad_right;
109
110    let outer: usize = shape[..ndim - 2].iter().product();
111    let outer = if outer == 0 { 1 } else { outer };
112
113    let mut out = vec![value; outer * new_h * new_w];
114    // Degenerate input (numel 0): no source data to copy in. Same rationale as
115    // `pad_1d_constant` — mirror upstream `PadNd.cpp:94-106` (#1551).
116    if !data.is_empty() {
117        for o in 0..outer {
118            for row in 0..h {
119                let src_off = o * h * w + row * w;
120                let dst_off = o * new_h * new_w + (row + pad_top) * new_w + pad_left;
121                out[dst_off..dst_off + w].copy_from_slice(&data[src_off..src_off + w]);
122            }
123        }
124    }
125
126    let mut new_shape = shape.to_vec();
127    new_shape[ndim - 2] = new_h;
128    new_shape[ndim - 1] = new_w;
129    (out, new_shape)
130}
131
132/// Pad the last 3 dimensions of a contiguous tensor with a constant value.
133// Internal kernel: signature mirrors PyTorch's `F.pad` 3-axis layout
134// (left, right, top, bottom, front, back); a config struct adds nothing.
135#[allow(clippy::too_many_arguments)]
136fn pad_3d_constant<T: Float>(
137    data: &[T],
138    shape: &[usize],
139    pad_left: usize,
140    pad_right: usize,
141    pad_top: usize,
142    pad_bottom: usize,
143    pad_front: usize,
144    pad_back: usize,
145    value: T,
146) -> (Vec<T>, Vec<usize>) {
147    let ndim = shape.len();
148    let d = shape[ndim - 3];
149    let h = shape[ndim - 2];
150    let w = shape[ndim - 1];
151    let new_d = d + pad_front + pad_back;
152    let new_h = h + pad_top + pad_bottom;
153    let new_w = w + pad_left + pad_right;
154
155    let outer: usize = shape[..ndim - 3].iter().product();
156    let outer = if outer == 0 { 1 } else { outer };
157
158    let mut out = vec![value; outer * new_d * new_h * new_w];
159    // Degenerate input (numel 0): no source data to copy in. Same rationale as
160    // `pad_1d_constant` — mirror upstream `PadNd.cpp:94-106` (#1551).
161    if !data.is_empty() {
162        for o in 0..outer {
163            for dep in 0..d {
164                for row in 0..h {
165                    let src_off = o * d * h * w + dep * h * w + row * w;
166                    let dst_off = o * new_d * new_h * new_w
167                        + (dep + pad_front) * new_h * new_w
168                        + (row + pad_top) * new_w
169                        + pad_left;
170                    out[dst_off..dst_off + w].copy_from_slice(&data[src_off..src_off + w]);
171                }
172            }
173        }
174    }
175
176    let mut new_shape = shape.to_vec();
177    new_shape[ndim - 3] = new_d;
178    new_shape[ndim - 2] = new_h;
179    new_shape[ndim - 1] = new_w;
180    (out, new_shape)
181}
182
183// ---------------------------------------------------------------------------
184// Reflection padding helpers
185// ---------------------------------------------------------------------------
186
187/// Reflect-pad the last dimension.
188fn pad_1d_reflect<T: Float>(
189    data: &[T],
190    shape: &[usize],
191    pad_left: usize,
192    pad_right: usize,
193) -> FerrotorchResult<(Vec<T>, Vec<usize>)> {
194    let ndim = shape.len();
195    let inner = shape[ndim - 1];
196    if pad_left >= inner || pad_right >= inner {
197        return Err(FerrotorchError::InvalidArgument {
198            message: format!(
199                "Reflection padding ({pad_left}, {pad_right}) must be less than input size ({inner})"
200            ),
201        });
202    }
203    let new_inner = inner + pad_left + pad_right;
204    let rows: usize = shape[..ndim - 1].iter().copied().product::<usize>().max(1);
205
206    let zero = <T as num_traits::Zero>::zero();
207    let mut out = vec![zero; rows * new_inner];
208    for r in 0..rows {
209        let src = &data[r * inner..(r + 1) * inner];
210        let dst = &mut out[r * new_inner..(r + 1) * new_inner];
211        // Left reflection
212        for i in 0..pad_left {
213            dst[pad_left - 1 - i] = src[i + 1];
214        }
215        // Copy original
216        dst[pad_left..pad_left + inner].copy_from_slice(src);
217        // Right reflection
218        for i in 0..pad_right {
219            dst[pad_left + inner + i] = src[inner - 2 - i];
220        }
221    }
222
223    let mut new_shape = shape.to_vec();
224    new_shape[ndim - 1] = new_inner;
225    Ok((out, new_shape))
226}
227
228/// Reflect-pad the last 2 dimensions.
229fn pad_2d_reflect<T: Float>(
230    data: &[T],
231    shape: &[usize],
232    pad_left: usize,
233    pad_right: usize,
234    pad_top: usize,
235    pad_bottom: usize,
236) -> FerrotorchResult<(Vec<T>, Vec<usize>)> {
237    let ndim = shape.len();
238    let h = shape[ndim - 2];
239    let w = shape[ndim - 1];
240    if pad_left >= w || pad_right >= w || pad_top >= h || pad_bottom >= h {
241        return Err(FerrotorchError::InvalidArgument {
242            message: format!(
243                "Reflection padding ({pad_left}, {pad_right}, {pad_top}, {pad_bottom}) must be less than input size ({h}, {w})"
244            ),
245        });
246    }
247    let new_h = h + pad_top + pad_bottom;
248    let new_w = w + pad_left + pad_right;
249    let outer: usize = shape[..ndim - 2].iter().copied().product::<usize>().max(1);
250
251    let zero = <T as num_traits::Zero>::zero();
252    let mut out = vec![zero; outer * new_h * new_w];
253
254    for o in 0..outer {
255        let src_base = o * h * w;
256        let dst_base = o * new_h * new_w;
257
258        for new_row in 0..new_h {
259            // Map new_row to source row via reflection
260            let src_row = if new_row < pad_top {
261                pad_top - new_row
262            } else if new_row >= pad_top + h {
263                h - 2 - (new_row - pad_top - h)
264            } else {
265                new_row - pad_top
266            };
267
268            for new_col in 0..new_w {
269                let src_col = if new_col < pad_left {
270                    pad_left - new_col
271                } else if new_col >= pad_left + w {
272                    w - 2 - (new_col - pad_left - w)
273                } else {
274                    new_col - pad_left
275                };
276
277                out[dst_base + new_row * new_w + new_col] = data[src_base + src_row * w + src_col];
278            }
279        }
280    }
281
282    let mut new_shape = shape.to_vec();
283    new_shape[ndim - 2] = new_h;
284    new_shape[ndim - 1] = new_w;
285    Ok((out, new_shape))
286}
287
288/// Reflect-pad the last 3 dimensions.
289// Internal kernel: same 3-axis pad descriptor as `pad_3d_constant`.
290#[allow(clippy::too_many_arguments)]
291fn pad_3d_reflect<T: Float>(
292    data: &[T],
293    shape: &[usize],
294    pad_left: usize,
295    pad_right: usize,
296    pad_top: usize,
297    pad_bottom: usize,
298    pad_front: usize,
299    pad_back: usize,
300) -> FerrotorchResult<(Vec<T>, Vec<usize>)> {
301    let ndim = shape.len();
302    let d = shape[ndim - 3];
303    let h = shape[ndim - 2];
304    let w = shape[ndim - 1];
305    if pad_left >= w
306        || pad_right >= w
307        || pad_top >= h
308        || pad_bottom >= h
309        || pad_front >= d
310        || pad_back >= d
311    {
312        return Err(FerrotorchError::InvalidArgument {
313            message: "Reflection padding must be less than corresponding input dimension".into(),
314        });
315    }
316    let new_d = d + pad_front + pad_back;
317    let new_h = h + pad_top + pad_bottom;
318    let new_w = w + pad_left + pad_right;
319    let outer: usize = shape[..ndim - 3].iter().copied().product::<usize>().max(1);
320
321    let zero = <T as num_traits::Zero>::zero();
322    let mut out = vec![zero; outer * new_d * new_h * new_w];
323
324    for o in 0..outer {
325        let src_base = o * d * h * w;
326        let dst_base = o * new_d * new_h * new_w;
327
328        for nd in 0..new_d {
329            let sd = if nd < pad_front {
330                pad_front - nd
331            } else if nd >= pad_front + d {
332                d - 2 - (nd - pad_front - d)
333            } else {
334                nd - pad_front
335            };
336            for nh in 0..new_h {
337                let sh = if nh < pad_top {
338                    pad_top - nh
339                } else if nh >= pad_top + h {
340                    h - 2 - (nh - pad_top - h)
341                } else {
342                    nh - pad_top
343                };
344                for nw in 0..new_w {
345                    let sw = if nw < pad_left {
346                        pad_left - nw
347                    } else if nw >= pad_left + w {
348                        w - 2 - (nw - pad_left - w)
349                    } else {
350                        nw - pad_left
351                    };
352                    out[dst_base + nd * new_h * new_w + nh * new_w + nw] =
353                        data[src_base + sd * h * w + sh * w + sw];
354                }
355            }
356        }
357    }
358
359    let mut new_shape = shape.to_vec();
360    new_shape[ndim - 3] = new_d;
361    new_shape[ndim - 2] = new_h;
362    new_shape[ndim - 1] = new_w;
363    Ok((out, new_shape))
364}
365
366// ---------------------------------------------------------------------------
367// Replication padding helpers
368// ---------------------------------------------------------------------------
369
370/// Replicate-pad the last dimension (clamp to edges).
371fn pad_1d_replicate<T: Float>(
372    data: &[T],
373    shape: &[usize],
374    pad_left: usize,
375    pad_right: usize,
376) -> (Vec<T>, Vec<usize>) {
377    let ndim = shape.len();
378    let inner = shape[ndim - 1];
379    let new_inner = inner + pad_left + pad_right;
380    let rows: usize = shape[..ndim - 1].iter().copied().product::<usize>().max(1);
381
382    let zero = <T as num_traits::Zero>::zero();
383    let mut out = vec![zero; rows * new_inner];
384    for r in 0..rows {
385        let src = &data[r * inner..(r + 1) * inner];
386        let dst = &mut out[r * new_inner..(r + 1) * new_inner];
387        for (i, d) in dst.iter_mut().enumerate() {
388            let src_idx = if i < pad_left {
389                0
390            } else if i >= pad_left + inner {
391                inner - 1
392            } else {
393                i - pad_left
394            };
395            *d = src[src_idx];
396        }
397    }
398
399    let mut new_shape = shape.to_vec();
400    new_shape[ndim - 1] = new_inner;
401    (out, new_shape)
402}
403
404/// Replicate-pad the last 2 dimensions.
405fn pad_2d_replicate<T: Float>(
406    data: &[T],
407    shape: &[usize],
408    pad_left: usize,
409    pad_right: usize,
410    pad_top: usize,
411    pad_bottom: usize,
412) -> (Vec<T>, Vec<usize>) {
413    let ndim = shape.len();
414    let h = shape[ndim - 2];
415    let w = shape[ndim - 1];
416    let new_h = h + pad_top + pad_bottom;
417    let new_w = w + pad_left + pad_right;
418    let outer: usize = shape[..ndim - 2].iter().copied().product::<usize>().max(1);
419
420    let zero = <T as num_traits::Zero>::zero();
421    let mut out = vec![zero; outer * new_h * new_w];
422
423    for o in 0..outer {
424        let src_base = o * h * w;
425        let dst_base = o * new_h * new_w;
426        for nr in 0..new_h {
427            let sr = nr.saturating_sub(pad_top).min(h - 1);
428            for nc in 0..new_w {
429                let sc = nc.saturating_sub(pad_left).min(w - 1);
430                out[dst_base + nr * new_w + nc] = data[src_base + sr * w + sc];
431            }
432        }
433    }
434
435    let mut new_shape = shape.to_vec();
436    new_shape[ndim - 2] = new_h;
437    new_shape[ndim - 1] = new_w;
438    (out, new_shape)
439}
440
441/// Replicate-pad the last 3 dimensions.
442// Internal kernel: same 3-axis pad descriptor as `pad_3d_constant`.
443#[allow(clippy::too_many_arguments)]
444fn pad_3d_replicate<T: Float>(
445    data: &[T],
446    shape: &[usize],
447    pad_left: usize,
448    pad_right: usize,
449    pad_top: usize,
450    pad_bottom: usize,
451    pad_front: usize,
452    pad_back: usize,
453) -> (Vec<T>, Vec<usize>) {
454    let ndim = shape.len();
455    let d = shape[ndim - 3];
456    let h = shape[ndim - 2];
457    let w = shape[ndim - 1];
458    let new_d = d + pad_front + pad_back;
459    let new_h = h + pad_top + pad_bottom;
460    let new_w = w + pad_left + pad_right;
461    let outer: usize = shape[..ndim - 3].iter().copied().product::<usize>().max(1);
462
463    let zero = <T as num_traits::Zero>::zero();
464    let mut out = vec![zero; outer * new_d * new_h * new_w];
465
466    for o in 0..outer {
467        let src_base = o * d * h * w;
468        let dst_base = o * new_d * new_h * new_w;
469        for nd in 0..new_d {
470            let sd = nd.saturating_sub(pad_front).min(d - 1);
471            for nh in 0..new_h {
472                let sh = nh.saturating_sub(pad_top).min(h - 1);
473                for nw in 0..new_w {
474                    let sw = nw.saturating_sub(pad_left).min(w - 1);
475                    out[dst_base + nd * new_h * new_w + nh * new_w + nw] =
476                        data[src_base + sd * h * w + sh * w + sw];
477                }
478            }
479        }
480    }
481
482    let mut new_shape = shape.to_vec();
483    new_shape[ndim - 3] = new_d;
484    new_shape[ndim - 2] = new_h;
485    new_shape[ndim - 1] = new_w;
486    (out, new_shape)
487}
488
489// ---------------------------------------------------------------------------
490// Circular padding helpers
491// ---------------------------------------------------------------------------
492
493/// Reject an all-non-negative circular pad that wraps around more than once.
494///
495/// The positive-only `pad_*_circular` helpers gather via `rem_euclid`, which
496/// silently wraps a pad strictly larger than the axis size MULTIPLE times
497/// (e.g. `circular [0,3]` on size 2 -> `[1,2,1,2,1]`). Upstream
498/// `_pad_circular_symint` rejects this at `aten/src/ATen/native/PadNd.cpp:142`:
499/// `TORCH_CHECK(pad_l <= size && pad_r <= size, "Padding value causes wrapping
500/// around more than once.")`. For a non-negative pad the net extent is always
501/// `>= size > 0`, so `:142` is the only check that can fire — mirror it here so
502/// the positive circular path matches torch's accept/reject (`pad <= size`).
503fn check_circular_positive(axes: &[(usize, usize)]) -> FerrotorchResult<()> {
504    for (idx, &(size, pad)) in axes.iter().enumerate() {
505        if pad > size {
506            return Err(FerrotorchError::InvalidArgument {
507                message: format!(
508                    "Circular padding {pad} on axis (size {size}, position {idx}) causes wrapping around more than once (pad must be <= size)"
509                ),
510            });
511        }
512    }
513    Ok(())
514}
515
516/// Circular-pad the last dimension (wrap-around).
517fn pad_1d_circular<T: Float>(
518    data: &[T],
519    shape: &[usize],
520    pad_left: usize,
521    pad_right: usize,
522) -> (Vec<T>, Vec<usize>) {
523    let ndim = shape.len();
524    let inner = shape[ndim - 1];
525    let new_inner = inner + pad_left + pad_right;
526    let rows: usize = shape[..ndim - 1].iter().copied().product::<usize>().max(1);
527
528    let zero = <T as num_traits::Zero>::zero();
529    let mut out = vec![zero; rows * new_inner];
530    for r in 0..rows {
531        let src = &data[r * inner..(r + 1) * inner];
532        let dst = &mut out[r * new_inner..(r + 1) * new_inner];
533        for (i, d) in dst.iter_mut().enumerate() {
534            // Map to source via modulo
535            let src_idx = ((i as isize - pad_left as isize).rem_euclid(inner as isize)) as usize;
536            *d = src[src_idx];
537        }
538    }
539
540    let mut new_shape = shape.to_vec();
541    new_shape[ndim - 1] = new_inner;
542    (out, new_shape)
543}
544
545/// Circular-pad the last 2 dimensions.
546fn pad_2d_circular<T: Float>(
547    data: &[T],
548    shape: &[usize],
549    pad_left: usize,
550    pad_right: usize,
551    pad_top: usize,
552    pad_bottom: usize,
553) -> (Vec<T>, Vec<usize>) {
554    let ndim = shape.len();
555    let h = shape[ndim - 2];
556    let w = shape[ndim - 1];
557    let new_h = h + pad_top + pad_bottom;
558    let new_w = w + pad_left + pad_right;
559    let outer: usize = shape[..ndim - 2].iter().copied().product::<usize>().max(1);
560
561    let zero = <T as num_traits::Zero>::zero();
562    let mut out = vec![zero; outer * new_h * new_w];
563
564    for o in 0..outer {
565        let src_base = o * h * w;
566        let dst_base = o * new_h * new_w;
567        for nr in 0..new_h {
568            let sr = ((nr as isize - pad_top as isize).rem_euclid(h as isize)) as usize;
569            for nc in 0..new_w {
570                let sc = ((nc as isize - pad_left as isize).rem_euclid(w as isize)) as usize;
571                out[dst_base + nr * new_w + nc] = data[src_base + sr * w + sc];
572            }
573        }
574    }
575
576    let mut new_shape = shape.to_vec();
577    new_shape[ndim - 2] = new_h;
578    new_shape[ndim - 1] = new_w;
579    (out, new_shape)
580}
581
582/// Circular-pad the last 3 dimensions.
583// Internal kernel: same 3-axis pad descriptor as `pad_3d_constant`.
584#[allow(clippy::too_many_arguments)]
585fn pad_3d_circular<T: Float>(
586    data: &[T],
587    shape: &[usize],
588    pad_left: usize,
589    pad_right: usize,
590    pad_top: usize,
591    pad_bottom: usize,
592    pad_front: usize,
593    pad_back: usize,
594) -> (Vec<T>, Vec<usize>) {
595    let ndim = shape.len();
596    let d = shape[ndim - 3];
597    let h = shape[ndim - 2];
598    let w = shape[ndim - 1];
599    let new_d = d + pad_front + pad_back;
600    let new_h = h + pad_top + pad_bottom;
601    let new_w = w + pad_left + pad_right;
602    let outer: usize = shape[..ndim - 3].iter().copied().product::<usize>().max(1);
603
604    let zero = <T as num_traits::Zero>::zero();
605    let mut out = vec![zero; outer * new_d * new_h * new_w];
606
607    for o in 0..outer {
608        let src_base = o * d * h * w;
609        let dst_base = o * new_d * new_h * new_w;
610        for nd in 0..new_d {
611            let sd = ((nd as isize - pad_front as isize).rem_euclid(d as isize)) as usize;
612            for nh in 0..new_h {
613                let sh = ((nh as isize - pad_top as isize).rem_euclid(h as isize)) as usize;
614                for nw in 0..new_w {
615                    let sw = ((nw as isize - pad_left as isize).rem_euclid(w as isize)) as usize;
616                    out[dst_base + nd * new_h * new_w + nh * new_w + nw] =
617                        data[src_base + sd * h * w + sh * w + sw];
618                }
619            }
620        }
621    }
622
623    let mut new_shape = shape.to_vec();
624    new_shape[ndim - 3] = new_d;
625    new_shape[ndim - 2] = new_h;
626    new_shape[ndim - 1] = new_w;
627    (out, new_shape)
628}
629
630// ===========================================================================
631// Public functional API — apply arbitrary padding to a Tensor
632// ===========================================================================
633
634// ---------------------------------------------------------------------------
635// Autograd for the 1-D functional pad path (used by Conv1d's non-zero
636// padding_mode pre-pad). Same gather/scatter-add adjoint as the 2-D case;
637// see the `Pad2dBackward` block below for the full derivation. A pad that
638// returns `requires_grad = false` severs autograd — the #1550 bug class that
639// the 2-D path already fixed; the 1-D path needs the same `Pad1dBackward`
640// node so Conv1d's input gradient flows through the reflect/replicate/circular
641// pre-pad. Mirrors upstream `torch/nn/modules/conv.py:367-371` routing
642// non-zero modes through the differentiable `F.pad`.
643// ---------------------------------------------------------------------------
644
645/// For an output element at `new_idx` in a 1-D pad, return the linear index
646/// into the (single) source row, or `None` if the element comes from the
647/// constant fill (Zeros mode) and has no source.
648fn src_index_1d(mode: PaddingMode, new_idx: usize, inner: usize, pad_left: usize) -> Option<usize> {
649    let s: usize = match mode {
650        PaddingMode::Zeros => {
651            if new_idx < pad_left || new_idx >= pad_left + inner {
652                return None;
653            }
654            new_idx - pad_left
655        }
656        PaddingMode::Reflect => {
657            if new_idx < pad_left {
658                pad_left - new_idx
659            } else if new_idx >= pad_left + inner {
660                inner - 2 - (new_idx - pad_left - inner)
661            } else {
662                new_idx - pad_left
663            }
664        }
665        PaddingMode::Replicate => new_idx.saturating_sub(pad_left).min(inner - 1),
666        PaddingMode::Circular => {
667            ((new_idx as isize - pad_left as isize).rem_euclid(inner as isize)) as usize
668        }
669    };
670    Some(s)
671}
672
673/// Backward node for the 1-D functional pad. Scatter-adds the output gradient
674/// back onto the unpadded input row using the per-output source-index map.
675#[derive(Debug)]
676struct Pad1dBackward<T: Float> {
677    input: Tensor<T>,
678    input_shape: Vec<usize>,
679    mode: PaddingMode,
680    pad_left: usize,
681}
682
683impl<T: Float> GradFn<T> for Pad1dBackward<T> {
684    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
685        if !self.input.requires_grad() {
686            return Ok(vec![None]);
687        }
688        let ndim = self.input_shape.len();
689        let inner = self.input_shape[ndim - 1];
690        let rows: usize = self.input_shape[..ndim - 1]
691            .iter()
692            .copied()
693            .product::<usize>()
694            .max(1);
695
696        let go_shape = grad_output.shape();
697        let new_inner = go_shape[ndim - 1];
698
699        // The backward runs on host: scatter-add is data-dependent over the
700        // index map. `data_vec` materialises the (possibly GPU) grad to CPU.
701        let go = grad_output.data_vec()?;
702        let zero = <T as num_traits::Zero>::zero();
703        let mut grad_in = vec![zero; rows * inner];
704
705        for r in 0..rows {
706            let go_base = r * new_inner;
707            let gi_base = r * inner;
708            for ni in 0..new_inner {
709                if let Some(src) = src_index_1d(self.mode, ni, inner, self.pad_left) {
710                    grad_in[gi_base + src] += go[go_base + ni];
711                }
712            }
713        }
714
715        let grad_input =
716            Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
717        Ok(vec![Some(grad_input)])
718    }
719
720    fn inputs(&self) -> Vec<&Tensor<T>> {
721        vec![&self.input]
722    }
723
724    fn name(&self) -> &'static str {
725        "Pad1dBackward"
726    }
727}
728
729/// Apply padding to the last dimension of a tensor using the given mode.
730///
731/// This is the functional version used internally by conv layers with
732/// `padding_mode`.
733///
734/// When `input` requires grad (and grad tracking is enabled) the returned
735/// tensor carries a [`Pad1dBackward`] node so gradients flow back to `input`,
736/// matching the differentiable `F.pad` that `torch/nn/modules/conv.py`
737/// `_conv_forward` routes non-zero `padding_mode`s through (Conv1d at
738/// `conv.py:367-371`).
739pub fn functional_pad_1d<T: Float>(
740    input: &Tensor<T>,
741    pad_left: usize,
742    pad_right: usize,
743    mode: PaddingMode,
744    value: T,
745) -> FerrotorchResult<Tensor<T>> {
746    // `Zeros` is the runner's mapping for torch `mode="constant"`; route it
747    // through the crop-capable signed constant path — the single source of
748    // truth for constant padding, mirroring torch dispatching `mode="constant"`
749    // through `constant_pad_nd` (`aten/src/ATen/native/PadNd.cpp:214-215`). For
750    // a non-negative `usize` pad the signed forward is byte-identical to the old
751    // `pad_1d_constant` and its `PadNdSignedBackward` scatter-add equals the old
752    // `Pad1dBackward` adjoint; the `value` fill (#1553) is preserved.
753    if mode == PaddingMode::Zeros {
754        return functional_pad_1d_signed(input, pad_left as isize, pad_right as isize, mode, value);
755    }
756
757    let data = input.data_vec()?;
758    let shape = input.shape();
759    let input_shape = shape.to_vec();
760    // The `Zeros` (constant) arm is dispatched above through the crop-capable
761    // signed path; the remaining gather modes never crop and keep their
762    // existing positive-only helpers + `Pad1dBackward` adjoint.
763    let (out_data, new_shape) = match mode {
764        PaddingMode::Reflect => pad_1d_reflect(&data, shape, pad_left, pad_right)?,
765        PaddingMode::Replicate => pad_1d_replicate(&data, shape, pad_left, pad_right),
766        PaddingMode::Circular => {
767            let inner = shape[shape.len() - 1];
768            check_circular_positive(&[(inner, pad_left), (inner, pad_right)])?;
769            pad_1d_circular(&data, shape, pad_left, pad_right)
770        }
771        PaddingMode::Zeros => {
772            return functional_pad_1d_signed(
773                input,
774                pad_left as isize,
775                pad_right as isize,
776                mode,
777                value,
778            );
779        }
780    };
781
782    // Grad path: attach Pad1dBackward so the autograd graph stays connected.
783    // Without this the prior `from_storage(.., false)` severed it (#1550 bug
784    // class), and Conv1d's input gradient would not flow through the non-zero
785    // padding_mode pre-pad.
786    if is_grad_enabled() && input.requires_grad() {
787        let grad_fn = Arc::new(Pad1dBackward {
788            input: input.clone(),
789            input_shape,
790            mode,
791            pad_left,
792        });
793        return Tensor::from_operation(TensorStorage::cpu(out_data), new_shape, grad_fn);
794    }
795
796    Tensor::from_storage(TensorStorage::cpu(out_data), new_shape, false)
797}
798
799// ---------------------------------------------------------------------------
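// Autograd for the 2-D functional pad path (used by Conv2d's non-zero
// padding_mode pre-pad). Every pad mode is a pure *gather*:
//   out[k] = input[src_index_2d(k)]   (or the constant `value` fill for an
//   out-of-bounds Zeros position, which has no source element).
// Note: `functional_pad_2d` dispatches `Zeros` to the signed constant path, so
// this node is only built for Reflect / Replicate / Circular.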
803// The adjoint (VJP) of a gather is a scatter-add into the unpadded input:
804//   grad_input[src_index_2d(k)] += grad_output[k].
805// This single rule is correct for ALL modes — Zeros (interior crop, padded
806// outputs have no source so contribute nothing), Reflect (the reflected
807// boundary source indices repeat, so their grads fold/accumulate back onto
808// the mirrored interior positions), Replicate (the edge source index repeats,
809// summing into the edge), and Circular (wrapped source indices accumulate
810// around). Mirrors upstream `torch/nn/modules/conv.py:367-371` routing
811// non-zero modes through the differentiable `F.pad`.
812// ---------------------------------------------------------------------------
813
814/// For an output element at `(new_row, new_col)` in a 2-D pad, return the
815/// linear index `sr * w + sc` into the (single) source plane, or `None` if the
816/// element comes from the constant fill (Zeros mode) and has no source.
817fn src_index_2d(
818    mode: PaddingMode,
819    new_row: usize,
820    new_col: usize,
821    h: usize,
822    w: usize,
823    pad_left: usize,
824    pad_top: usize,
825) -> Option<usize> {
826    let sr: usize = match mode {
827        PaddingMode::Zeros => {
828            if new_row < pad_top || new_row >= pad_top + h {
829                return None;
830            }
831            new_row - pad_top
832        }
833        PaddingMode::Reflect => {
834            if new_row < pad_top {
835                pad_top - new_row
836            } else if new_row >= pad_top + h {
837                h - 2 - (new_row - pad_top - h)
838            } else {
839                new_row - pad_top
840            }
841        }
842        PaddingMode::Replicate => new_row.saturating_sub(pad_top).min(h - 1),
843        PaddingMode::Circular => {
844            ((new_row as isize - pad_top as isize).rem_euclid(h as isize)) as usize
845        }
846    };
847    let sc: usize = match mode {
848        PaddingMode::Zeros => {
849            if new_col < pad_left || new_col >= pad_left + w {
850                return None;
851            }
852            new_col - pad_left
853        }
854        PaddingMode::Reflect => {
855            if new_col < pad_left {
856                pad_left - new_col
857            } else if new_col >= pad_left + w {
858                w - 2 - (new_col - pad_left - w)
859            } else {
860                new_col - pad_left
861            }
862        }
863        PaddingMode::Replicate => new_col.saturating_sub(pad_left).min(w - 1),
864        PaddingMode::Circular => {
865            ((new_col as isize - pad_left as isize).rem_euclid(w as isize)) as usize
866        }
867    };
868    Some(sr * w + sc)
869}
870
871/// Backward node for the 2-D functional pad. Scatter-adds the output gradient
872/// back onto the unpadded input plane using the per-output source-index map.
873#[derive(Debug)]
874struct Pad2dBackward<T: Float> {
875    input: Tensor<T>,
876    input_shape: Vec<usize>,
877    mode: PaddingMode,
878    pad_left: usize,
879    pad_top: usize,
880}
881
882impl<T: Float> GradFn<T> for Pad2dBackward<T> {
883    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
884        if !self.input.requires_grad() {
885            return Ok(vec![None]);
886        }
887        let ndim = self.input_shape.len();
888        let h = self.input_shape[ndim - 2];
889        let w = self.input_shape[ndim - 1];
890        let outer: usize = self.input_shape[..ndim - 2]
891            .iter()
892            .copied()
893            .product::<usize>()
894            .max(1);
895
896        let go_shape = grad_output.shape();
897        let new_h = go_shape[ndim - 2];
898        let new_w = go_shape[ndim - 1];
899
900        // The backward runs on host: scatter-add is data-dependent over the
901        // index map. `data_vec` materialises the (possibly GPU) grad to CPU.
902        let go = grad_output.data_vec()?;
903        let zero = <T as num_traits::Zero>::zero();
904        let mut grad_in = vec![zero; outer * h * w];
905
906        for o in 0..outer {
907            let go_base = o * new_h * new_w;
908            let gi_base = o * h * w;
909            for nr in 0..new_h {
910                for nc in 0..new_w {
911                    if let Some(src) =
912                        src_index_2d(self.mode, nr, nc, h, w, self.pad_left, self.pad_top)
913                    {
914                        grad_in[gi_base + src] += go[go_base + nr * new_w + nc];
915                    }
916                }
917            }
918        }
919
920        let grad_input =
921            Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
922        Ok(vec![Some(grad_input)])
923    }
924
925    fn inputs(&self) -> Vec<&Tensor<T>> {
926        vec![&self.input]
927    }
928
929    fn name(&self) -> &'static str {
930        "Pad2dBackward"
931    }
932}
933
934/// Apply padding to the last 2 dimensions of a tensor using the given mode.
935///
936/// When `input` requires grad (and grad tracking is enabled) the returned
937/// tensor carries a [`Pad2dBackward`] node so gradients flow back to `input`,
938/// matching the differentiable `F.pad` that `torch/nn/modules/conv.py`
939/// `_conv_forward` routes non-zero `padding_mode`s through.
940pub fn functional_pad_2d<T: Float>(
941    input: &Tensor<T>,
942    pad_left: usize,
943    pad_right: usize,
944    pad_top: usize,
945    pad_bottom: usize,
946    mode: PaddingMode,
947    value: T,
948) -> FerrotorchResult<Tensor<T>> {
949    // `Zeros` (torch `mode="constant"`) routes through the crop-capable signed
950    // path — see the `functional_pad_1d` note. The `value` fill (#1553) is
951    // preserved; for non-negative `usize` pads the result is byte-identical.
952    if mode == PaddingMode::Zeros {
953        return functional_pad_2d_signed(
954            input,
955            pad_left as isize,
956            pad_right as isize,
957            pad_top as isize,
958            pad_bottom as isize,
959            mode,
960            value,
961        );
962    }
963
964    let data = input.data_vec()?;
965    let shape = input.shape();
966    let input_shape = shape.to_vec();
967    let (out_data, new_shape) = match mode {
968        PaddingMode::Reflect => {
969            pad_2d_reflect(&data, shape, pad_left, pad_right, pad_top, pad_bottom)?
970        }
971        PaddingMode::Replicate => {
972            pad_2d_replicate(&data, shape, pad_left, pad_right, pad_top, pad_bottom)
973        }
974        PaddingMode::Circular => {
975            let nd = shape.len();
976            let (h, w) = (shape[nd - 2], shape[nd - 1]);
977            check_circular_positive(&[
978                (w, pad_left),
979                (w, pad_right),
980                (h, pad_top),
981                (h, pad_bottom),
982            ])?;
983            pad_2d_circular(&data, shape, pad_left, pad_right, pad_top, pad_bottom)
984        }
985        PaddingMode::Zeros => {
986            return functional_pad_2d_signed(
987                input,
988                pad_left as isize,
989                pad_right as isize,
990                pad_top as isize,
991                pad_bottom as isize,
992                mode,
993                value,
994            );
995        }
996    };
997
998    // Grad path: attach Pad2dBackward so the autograd graph stays connected
999    // (the prior `from_storage(..., false)` severed it — #1550).
1000    if is_grad_enabled() && input.requires_grad() {
1001        let grad_fn = Arc::new(Pad2dBackward {
1002            input: input.clone(),
1003            input_shape,
1004            mode,
1005            pad_left,
1006            pad_top,
1007        });
1008        return Tensor::from_operation(TensorStorage::cpu(out_data), new_shape, grad_fn);
1009    }
1010
1011    Tensor::from_storage(TensorStorage::cpu(out_data), new_shape, false)
1012}
1013
1014// ---------------------------------------------------------------------------
1015// Autograd for the 3-D functional pad path (used by Conv3d's non-zero
1016// padding_mode pre-pad). Same gather/scatter-add adjoint as the 1-D / 2-D
1017// cases; see the `Pad2dBackward` block above for the full derivation. Without
1018// the backward node, the pad returns `requires_grad = false` and severs
1019// autograd — the #1550 bug class. Mirrors upstream `torch/nn/modules/conv.py`
1020// `Conv3d._conv_forward` (`conv.py:717-721`) routing non-zero modes through
1021// the differentiable `F.pad`.
1022// ---------------------------------------------------------------------------
1023
1024/// For an output element at `(nd, nh, nw)` in a 3-D pad, return the linear
1025/// index `sd*H*W + sh*W + sw` into the (single) source volume, or `None` if
1026/// the element comes from the constant fill (Zeros mode) and has no source.
1027// Internal helper: the 3-axis pad descriptor (d/h/w + per-axis pad) carries
1028// proportionally more arguments than the 1-D/2-D variants.
1029#[allow(clippy::too_many_arguments)]
1030fn src_index_3d(
1031    mode: PaddingMode,
1032    nd: usize,
1033    nh: usize,
1034    nw: usize,
1035    d: usize,
1036    h: usize,
1037    w: usize,
1038    pad_left: usize,
1039    pad_top: usize,
1040    pad_front: usize,
1041) -> Option<usize> {
1042    // Axis-wise source resolver shared across all three spatial axes. Returns
1043    // `None` only for the Zeros mode out-of-bounds case (constant fill).
1044    fn axis(mode: PaddingMode, new_idx: usize, size: usize, pad_lo: usize) -> Option<usize> {
1045        let s = match mode {
1046            PaddingMode::Zeros => {
1047                if new_idx < pad_lo || new_idx >= pad_lo + size {
1048                    return None;
1049                }
1050                new_idx - pad_lo
1051            }
1052            PaddingMode::Reflect => {
1053                if new_idx < pad_lo {
1054                    pad_lo - new_idx
1055                } else if new_idx >= pad_lo + size {
1056                    size - 2 - (new_idx - pad_lo - size)
1057                } else {
1058                    new_idx - pad_lo
1059                }
1060            }
1061            PaddingMode::Replicate => new_idx.saturating_sub(pad_lo).min(size - 1),
1062            PaddingMode::Circular => {
1063                ((new_idx as isize - pad_lo as isize).rem_euclid(size as isize)) as usize
1064            }
1065        };
1066        Some(s)
1067    }
1068    let sd = axis(mode, nd, d, pad_front)?;
1069    let sh = axis(mode, nh, h, pad_top)?;
1070    let sw = axis(mode, nw, w, pad_left)?;
1071    Some(sd * h * w + sh * w + sw)
1072}
1073
1074/// Backward node for the 3-D functional pad. Scatter-adds the output gradient
1075/// back onto the unpadded input volume using the per-output source-index map.
1076#[derive(Debug)]
1077struct Pad3dBackward<T: Float> {
1078    input: Tensor<T>,
1079    input_shape: Vec<usize>,
1080    mode: PaddingMode,
1081    pad_left: usize,
1082    pad_top: usize,
1083    pad_front: usize,
1084}
1085
1086impl<T: Float> GradFn<T> for Pad3dBackward<T> {
1087    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1088        if !self.input.requires_grad() {
1089            return Ok(vec![None]);
1090        }
1091        let ndim = self.input_shape.len();
1092        let d = self.input_shape[ndim - 3];
1093        let h = self.input_shape[ndim - 2];
1094        let w = self.input_shape[ndim - 1];
1095        let outer: usize = self.input_shape[..ndim - 3]
1096            .iter()
1097            .copied()
1098            .product::<usize>()
1099            .max(1);
1100
1101        let go_shape = grad_output.shape();
1102        let new_d = go_shape[ndim - 3];
1103        let new_h = go_shape[ndim - 2];
1104        let new_w = go_shape[ndim - 1];
1105
1106        // The backward runs on host: scatter-add is data-dependent over the
1107        // index map. `data_vec` materialises the (possibly GPU) grad to CPU.
1108        let go = grad_output.data_vec()?;
1109        let zero = <T as num_traits::Zero>::zero();
1110        let mut grad_in = vec![zero; outer * d * h * w];
1111
1112        for o in 0..outer {
1113            let go_base = o * new_d * new_h * new_w;
1114            let gi_base = o * d * h * w;
1115            for ndp in 0..new_d {
1116                for nhp in 0..new_h {
1117                    for nwp in 0..new_w {
1118                        if let Some(src) = src_index_3d(
1119                            self.mode,
1120                            ndp,
1121                            nhp,
1122                            nwp,
1123                            d,
1124                            h,
1125                            w,
1126                            self.pad_left,
1127                            self.pad_top,
1128                            self.pad_front,
1129                        ) {
1130                            grad_in[gi_base + src] +=
1131                                go[go_base + ndp * new_h * new_w + nhp * new_w + nwp];
1132                        }
1133                    }
1134                }
1135            }
1136        }
1137
1138        let grad_input =
1139            Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
1140        Ok(vec![Some(grad_input)])
1141    }
1142
1143    fn inputs(&self) -> Vec<&Tensor<T>> {
1144        vec![&self.input]
1145    }
1146
1147    fn name(&self) -> &'static str {
1148        "Pad3dBackward"
1149    }
1150}
1151
1152/// Apply padding to the last 3 dimensions of a tensor using the given mode.
1153///
1154/// When `input` requires grad (and grad tracking is enabled) the returned
1155/// tensor carries a [`Pad3dBackward`] node so gradients flow back to `input`,
1156/// matching the differentiable `F.pad` that `torch/nn/modules/conv.py`
1157/// `Conv3d._conv_forward` routes non-zero `padding_mode`s through
1158/// (`conv.py:717-721`).
1159// Public API: matches PyTorch's `torch.nn.functional.pad` signature for the
1160// 3-axis case (input + 6 pad amounts + mode + value); divergence would
1161// break parity with the upstream reference.
1162#[allow(clippy::too_many_arguments)]
1163pub fn functional_pad_3d<T: Float>(
1164    input: &Tensor<T>,
1165    pad_left: usize,
1166    pad_right: usize,
1167    pad_top: usize,
1168    pad_bottom: usize,
1169    pad_front: usize,
1170    pad_back: usize,
1171    mode: PaddingMode,
1172    value: T,
1173) -> FerrotorchResult<Tensor<T>> {
1174    // `Zeros` (torch `mode="constant"`) routes through the crop-capable signed
1175    // path — see the `functional_pad_1d` note. The `value` fill (#1553) is
1176    // preserved; for non-negative `usize` pads the result is byte-identical.
1177    if mode == PaddingMode::Zeros {
1178        return functional_pad_3d_signed(
1179            input,
1180            pad_left as isize,
1181            pad_right as isize,
1182            pad_top as isize,
1183            pad_bottom as isize,
1184            pad_front as isize,
1185            pad_back as isize,
1186            mode,
1187            value,
1188        );
1189    }
1190
1191    let data = input.data_vec()?;
1192    let shape = input.shape();
1193    let input_shape = shape.to_vec();
1194    let (out_data, new_shape) = match mode {
1195        PaddingMode::Reflect => pad_3d_reflect(
1196            &data, shape, pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back,
1197        )?,
1198        PaddingMode::Replicate => pad_3d_replicate(
1199            &data, shape, pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back,
1200        ),
1201        PaddingMode::Circular => {
1202            let nd = shape.len();
1203            let (d, h, w) = (shape[nd - 3], shape[nd - 2], shape[nd - 1]);
1204            check_circular_positive(&[
1205                (w, pad_left),
1206                (w, pad_right),
1207                (h, pad_top),
1208                (h, pad_bottom),
1209                (d, pad_front),
1210                (d, pad_back),
1211            ])?;
1212            pad_3d_circular(
1213                &data, shape, pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back,
1214            )
1215        }
1216        PaddingMode::Zeros => {
1217            return functional_pad_3d_signed(
1218                input,
1219                pad_left as isize,
1220                pad_right as isize,
1221                pad_top as isize,
1222                pad_bottom as isize,
1223                pad_front as isize,
1224                pad_back as isize,
1225                mode,
1226                value,
1227            );
1228        }
1229    };
1230
1231    // Grad path: attach Pad3dBackward so the autograd graph stays connected.
1232    // Without this the prior `from_storage(.., false)` severed it (#1550 bug
1233    // class), and Conv3d's input gradient would not flow through the non-zero
1234    // padding_mode pre-pad.
1235    if is_grad_enabled() && input.requires_grad() {
1236        let grad_fn = Arc::new(Pad3dBackward {
1237            input: input.clone(),
1238            input_shape,
1239            mode,
1240            pad_left,
1241            pad_top,
1242            pad_front,
1243        });
1244        return Tensor::from_operation(TensorStorage::cpu(out_data), new_shape, grad_fn);
1245    }
1246
1247    Tensor::from_storage(TensorStorage::cpu(out_data), new_shape, false)
1248}
1249
1250// ===========================================================================
1251// Signed (crop-capable) functional pad — torch `constant_pad_nd` negative pad
1252// ===========================================================================
1253//
1254// `torch.nn.functional.pad` accepts NEGATIVE pad amounts: a negative pad on a
1255// side CROPS (removes) `|pad|` elements from that side instead of adding. ALL
1256// modes support this — upstream `aten/src/ATen/native/PadNd.cpp:207-242`
1257// (`_pad_enum_symint`) routes `constant` through `constant_pad_nd` (which
1258// narrows for negatives) and `reflect`/`replicate`/`circular` straight to the
1259// native `reflection_pad*` / `replication_pad*` / `_pad_circular` kernels, which
1260// compute `output = input + pad_l + pad_r` (a negative pad narrows the side)
1261// and gather with offset `max(0,-pad) - max(0,pad)` (ReflectionPad.cpp:46,
1262// PaddingKernel.cpp:63-65, PadNd.cpp:158-159). The non-constant modes still
1263// reject a non-zero `value` (PadNd.cpp:217-219). This signed-constant gather is
1264// the `PaddingMode::Zeros` forward; the other modes compose crop-then-pad
1265// (`functional_pad_nd_signed`), which is byte-identical to their native kernels.
1266//
1267// Forward (PadNd.cpp:29-108): for each padded dim with signed pads `(lo, hi)`
1268// the cropped input is `narrow(i, -lo, size+lo)` (when `lo<0`) then
1269// `narrow(i, 0, size'+hi)` (when `hi<0`); the output of size
1270// `new = size + lo + hi` is `fill_(value)`d and the cropped input copied into
1271// the `max(lo,0)` offset window (PadNd.cpp:94-106). Equivalently, an output
1272// index `o` reads source index `s = o - lo`: when `0 <= s < size` it is real
1273// data, otherwise (only possible for the POSITIVE-pad region) it is the `value`
1274// fill. This one rule handles MIXED signs per-dim correctly.
1275//
1276// Over-crop: torch's `narrow` rejects a negative length
1277// ("narrow(): length must be non-negative", from PadNd.cpp:49 / :54), and
1278// PadNd.cpp:76 `TORCH_CHECK(new_dim >= 0)`. We mirror BOTH: a left crop may not
1279// exceed `size`, and a right crop may not exceed the post-left-crop size — i.e.
1280// `size + min(lo,0) >= 0` AND `size + min(lo,0) + min(hi,0) >= 0`. A net size of
1281// exactly 0 is allowed (torch returns an empty dim, e.g. `F.pad(x3, [-1,-2])`).
1282//
1283// Backward: the adjoint of a crop-or-pad gather is a scatter-add into the
1284// (full, original-size) input — `grad_input[o - lo] += grad_output[o]` for the
1285// in-bounds outputs. Cropped-away positions receive no contribution (grad 0),
1286// matching torch's `constant_pad_nd` backward being itself a `constant_pad_nd`
1287// with negated pads.
1288
/// Find the source index that output position `new_idx` reads along one axis
/// of a signed (pad-or-crop) constant pad.
///
/// `lo` is the signed low-side pad: positive values shift the data right
/// (leaving a fill border), negative values crop from the front. Yields `None`
/// when the position maps before index 0 or at/after `size`, i.e. it belongs
/// to the constant-fill border and has no source element.
#[inline]
fn signed_axis_src(new_idx: usize, size: usize, lo: isize) -> Option<usize> {
    let candidate = new_idx as isize - lo;
    let in_range = candidate >= 0 && (candidate as usize) < size;
    in_range.then_some(candidate as usize)
}
1302
1303/// Validate the signed pads for one axis against torch's sequential-`narrow`
1304/// crop rule and return the new axis size. Errors when a crop removes more than
1305/// the (running) axis size — mirroring torch's
1306/// "narrow(): length must be non-negative" / `TORCH_CHECK(new_dim >= 0)`.
1307fn signed_axis_new_size(
1308    size: usize,
1309    lo: isize,
1310    hi: isize,
1311    axis_label: &str,
1312) -> FerrotorchResult<usize> {
1313    // Left crop applies first (PadNd.cpp:49): narrow length `size + lo` must be
1314    // non-negative when `lo < 0`.
1315    let after_left: isize = if lo < 0 {
1316        size as isize + lo
1317    } else {
1318        size as isize
1319    };
1320    if after_left < 0 {
1321        return Err(FerrotorchError::InvalidArgument {
1322            message: format!(
1323                "constant pad: negative padding {lo} on {axis_label} crops more than the dimension size {size} (narrow length would be negative)"
1324            ),
1325        });
1326    }
1327    // Right crop applies to the post-left size (PadNd.cpp:54).
1328    let after_right: isize = if hi < 0 { after_left + hi } else { after_left };
1329    if after_right < 0 {
1330        return Err(FerrotorchError::InvalidArgument {
1331            message: format!(
1332                "constant pad: negative padding ({lo}, {hi}) on {axis_label} crops more than the dimension size {size}, resulting in a negative output size"
1333            ),
1334        });
1335    }
1336    // The actual new size also adds the POSITIVE side of each pad back in.
1337    Ok((after_right + lo.max(0) + hi.max(0)) as usize)
1338}
1339
1340/// Generic crop-capable constant pad over the last `npad` dimensions.
1341///
1342/// `pads` is `[lo_0, hi_0, lo_1, hi_1, ...]` ordered from the LAST padded axis
1343/// inward (i.e. matching the `(left, right, top, bottom, front, back)`
1344/// flattened layout the public entrypoints use). Returns `(data, new_shape)`.
1345fn pad_nd_signed_constant<T: Float>(
1346    data: &[T],
1347    shape: &[usize],
1348    pads: &[(isize, isize)],
1349    value: T,
1350) -> FerrotorchResult<(Vec<T>, Vec<usize>)> {
1351    let ndim = shape.len();
1352    let npad = pads.len();
1353    // `pads[0]` targets the LAST axis; map axis k (0-based from the last padded
1354    // axis) to absolute dim `ndim - 1 - k`.
1355    let mut new_shape = shape.to_vec();
1356    let mut new_sizes = vec![0usize; npad]; // new_sizes[k] for axis ndim-1-k
1357    for (k, &(lo, hi)) in pads.iter().enumerate() {
1358        let dim = ndim - 1 - k;
1359        let new_size = signed_axis_new_size(shape[dim], lo, hi, &format!("dimension {dim}"))?;
1360        new_sizes[k] = new_size;
1361        new_shape[dim] = new_size;
1362    }
1363
1364    // Outer dims (everything before the first padded axis) are untouched.
1365    let first_padded = ndim - npad;
1366    let outer: usize = shape[..first_padded]
1367        .iter()
1368        .copied()
1369        .product::<usize>()
1370        .max(1);
1371
1372    let new_total: usize = new_shape.iter().copied().product();
1373    let mut out = vec![value; new_total];
1374
1375    // Degenerate input (numel 0 — e.g. shape `[0, 3]`: empty data buffer with a
1376    // non-empty declared dim): no source data to gather. Mirror upstream
1377    // `aten/src/ATen/native/PadNd.cpp:94-106`, which `fill_(value)`s the output
1378    // then `copy_`s the (empty) source — a no-op — leaving the value-filled
1379    // output. The guard prevents an out-of-bounds index into the empty `data`
1380    // (same #1551 bug class the positive-only `pad_*d_constant` helpers guard).
1381    if data.is_empty() {
1382        return Ok((out, new_shape));
1383    }
1384
1385    // Per-element gather over the padded sub-volume. `npad` is at most 3 here,
1386    // so a small fixed-stride walk over the last axes is sufficient and clear.
1387    // Strides within the (single outer slice of the) input / output.
1388    let in_inner: usize = shape[first_padded..].iter().product();
1389    let out_inner: usize = new_shape[first_padded..].iter().product();
1390
1391    // Source coordinate buffer reused per output element.
1392    for o in 0..outer {
1393        let in_base = o * in_inner;
1394        let out_base = o * out_inner;
1395        for flat in 0..out_inner {
1396            // Decode `flat` into per-axis output coords (last axis fastest).
1397            let mut rem = flat;
1398            let mut src_lin = 0usize;
1399            let mut src_stride = 1usize;
1400            let mut missing = false;
1401            // Walk axes from last (k=0) to first padded (k=npad-1).
1402            for k in 0..npad {
1403                let dim = ndim - 1 - k;
1404                let axis_new = new_shape[dim];
1405                let coord = rem % axis_new;
1406                rem /= axis_new;
1407                let lo = pads[k].0;
1408                match signed_axis_src(coord, shape[dim], lo) {
1409                    Some(s) => {
1410                        src_lin += s * src_stride;
1411                        src_stride *= shape[dim];
1412                    }
1413                    None => {
1414                        missing = true;
1415                        break;
1416                    }
1417                }
1418            }
1419            if !missing {
1420                out[out_base + flat] = data[in_base + src_lin];
1421            }
1422            // else: leave the `value` fill already in place.
1423        }
1424    }
1425
1426    Ok((out, new_shape))
1427}
1428
1429/// Backward node for the signed (crop-capable) constant functional pad. The
1430/// adjoint of the crop/pad gather is a scatter-add into the original-size
1431/// input: `grad_input[o - lo] += grad_output[o]` for in-bounds outputs. Cropped
1432/// positions get no contribution (grad 0). Mirrors torch's `constant_pad_nd`
1433/// backward (itself a `constant_pad_nd` with negated pads).
1434#[derive(Debug)]
1435struct PadNdSignedBackward<T: Float> {
1436    input: Tensor<T>,
1437    input_shape: Vec<usize>,
1438    /// `(lo, hi)` per padded axis, ordered LAST axis first (same as the forward).
1439    pads: Vec<(isize, isize)>,
1440}
1441
1442impl<T: Float> GradFn<T> for PadNdSignedBackward<T> {
1443    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1444        if !self.input.requires_grad() {
1445            return Ok(vec![None]);
1446        }
1447        let ndim = self.input_shape.len();
1448        let npad = self.pads.len();
1449        let first_padded = ndim - npad;
1450        let outer: usize = self.input_shape[..first_padded]
1451            .iter()
1452            .copied()
1453            .product::<usize>()
1454            .max(1);
1455        let in_inner: usize = self.input_shape[first_padded..].iter().product();
1456
1457        let go_shape = grad_output.shape();
1458        let out_inner: usize = go_shape[first_padded..].iter().product();
1459
1460        // The backward runs on host: scatter-add is data-dependent over the
1461        // index map. `data_vec` materialises the (possibly GPU) grad to CPU.
1462        let go = grad_output.data_vec()?;
1463        let zero = <T as num_traits::Zero>::zero();
1464        let mut grad_in = vec![zero; outer * in_inner];
1465
1466        for o in 0..outer {
1467            let in_base = o * in_inner;
1468            let out_base = o * out_inner;
1469            for flat in 0..out_inner {
1470                let mut rem = flat;
1471                let mut src_lin = 0usize;
1472                let mut src_stride = 1usize;
1473                let mut missing = false;
1474                for k in 0..npad {
1475                    let dim = ndim - 1 - k;
1476                    let axis_new = go_shape[dim];
1477                    let coord = rem % axis_new;
1478                    rem /= axis_new;
1479                    let lo = self.pads[k].0;
1480                    match signed_axis_src(coord, self.input_shape[dim], lo) {
1481                        Some(s) => {
1482                            src_lin += s * src_stride;
1483                            src_stride *= self.input_shape[dim];
1484                        }
1485                        None => {
1486                            missing = true;
1487                            break;
1488                        }
1489                    }
1490                }
1491                if !missing {
1492                    grad_in[in_base + src_lin] += go[out_base + flat];
1493                }
1494            }
1495        }
1496
1497        let grad_input =
1498            Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
1499        Ok(vec![Some(grad_input)])
1500    }
1501
1502    fn inputs(&self) -> Vec<&Tensor<T>> {
1503        vec![&self.input]
1504    }
1505
1506    fn name(&self) -> &'static str {
1507        "PadNdSignedBackward"
1508    }
1509}
1510
1511/// Apply the all-non-negative pad part of `pads` under a non-`Zeros` mode by
1512/// delegating to the positive-only helpers, so reflect/replicate/circular keep
1513/// their exact gather + autograd behaviour. `pads` is LAST axis first.
1514fn functional_pad_nd_positive<T: Float>(
1515    input: &Tensor<T>,
1516    pads: &[(isize, isize)],
1517    mode: PaddingMode,
1518    value: T,
1519) -> FerrotorchResult<Tensor<T>> {
1520    match pads.len() {
1521        1 => functional_pad_1d(input, pads[0].0 as usize, pads[0].1 as usize, mode, value),
1522        2 => functional_pad_2d(
1523            input,
1524            pads[0].0 as usize,
1525            pads[0].1 as usize,
1526            pads[1].0 as usize,
1527            pads[1].1 as usize,
1528            mode,
1529            value,
1530        ),
1531        3 => functional_pad_3d(
1532            input,
1533            pads[0].0 as usize,
1534            pads[0].1 as usize,
1535            pads[1].0 as usize,
1536            pads[1].1 as usize,
1537            pads[2].0 as usize,
1538            pads[2].1 as usize,
1539            mode,
1540            value,
1541        ),
1542        other => Err(FerrotorchError::InvalidArgument {
1543            message: format!("functional_pad_nd_signed supports 1-3 padded dims, got {other}"),
1544        }),
1545    }
1546}
1547
/// Reflect-mode source index for output position `j` on one axis.
///
/// Mirrors upstream `aten/src/ATen/native/cpu/PaddingKernel.cpp:63-80`.
/// `size` is the ORIGINAL input extent of the axis and `pad` is the signed
/// LOW-side pad. The output position is first mirrored about the window edges
/// `pad` and `size + pad - 1`. It is then shifted by the window offset
/// `max(0, -pad) - max(0, pad)` (`PaddingKernel.cpp:63-65`), so the result
/// indexes the ORIGINAL input directly (`PaddingKernel.cpp:71-80`).
///
/// Reading the original window, rather than a crop-first copy, lets a positive
/// pad on a cropped side reach elements that a crop-then-pad pass would
/// already have discarded.
///
/// The caller guarantees the resolved index lies in `0..size`; the reflect
/// legality check (`|pad| < size` per side) provides this. Nothing here
/// re-checks it.
#[inline]
fn reflect_axis_src(j: usize, size: usize, pad: isize) -> usize {
    let pos = j as isize;
    let extent = size as isize;
    // Window offset, evaluated in i64 like the upstream kernel.
    let pad_wide = pad as i64;
    let window_offset = ((-pad_wide).max(0) - pad_wide.max(0)) as isize;

    let mirrored = match pos {
        // Before the window: mirror about the low edge `pad`.
        p if p < pad => 2 * pad - p,
        // Inside the window: identity.
        p if p < extent + pad => p,
        // Past the window: mirror about the high edge `size + pad - 1`.
        p => 2 * (extent + pad - 1) - p,
    };
    (mirrored + window_offset) as usize
}
1573
/// Replicate-mode source index for output position `j` on one axis.
///
/// Mirrors upstream `ReplicationPad::index`
/// (`aten/src/ATen/native/cpu/PaddingKernel.cpp:84-95`). `size` is the
/// ORIGINAL input extent and `pad` the signed LOW-side pad. The output
/// position is clamped onto the window `[pad, size + pad - 1]`:
/// - before the window it snaps to `pad`;
/// - past the window it snaps to `size + pad - 1`;
/// - inside the window it stays `j`.
///
/// The clamped position is then shifted by the window offset
/// `max(0, -pad) - max(0, pad)` (`PaddingKernel.cpp:63-65`), so it indexes the
/// ORIGINAL input (`PaddingKernel.cpp:87-94`).
///
/// Because the lookup always resolves against the original window, an
/// over-crop that leaves a zero-size axis still reads the preserved edge
/// element: no `inner - 1` underflow and no panic (#1625, R-CODE-2). For a
/// non-negative pad this matches the old crop-then-pad clamp exactly.
///
/// The caller guarantees `size >= 1`: an empty original axis cannot be
/// replicated, and the legality check rejects it.
#[inline]
fn replicate_axis_src(j: usize, size: usize, pad: isize) -> usize {
    let pos = j as isize;
    let extent = size as isize;
    let pad_wide = pad as i64;
    let window_offset = ((-pad_wide).max(0) - pad_wide.max(0)) as isize;

    let clamped = match pos {
        p if p < pad => pad,
        p if p < extent + pad => p,
        _ => extent + pad - 1,
    };
    (clamped + window_offset) as usize
}
1603
/// Circular-mode source index for output position `j` on one axis, following
/// the slice-copy gather of `_pad_circular` (`aten/src/ATen/native/PadNd.cpp:148-187`).
///
/// Upstream copies the (possibly cropped) center slice
/// `in[max(-lo,0) .. size-max(-hi,0)]` into `out[max(lo,0) .. out_w-max(hi,0)]`.
/// It then fills the left pad from the END of the output center and the right
/// pad from its START. A wrap therefore reads the CROPPED center, which only
/// coincides with a plain modulo over the original window when nothing is
/// cropped.
///
/// `size` is the ORIGINAL input extent and `(lo, hi)` are the signed pads on
/// this axis. The result is the RAW signed index into the original input. For
/// an illegal pad it can fall outside `0..size`, so callers must validate it
/// before treating it as an index. Only meaningful for `out_w >= 1`: an empty
/// axis never runs the gather.
#[inline]
fn circular_axis_src(j: usize, size: usize, lo: isize, hi: isize) -> isize {
    let pos = j as isize;
    let out_w = size as isize + lo + hi;
    let left_wrap = lo.max(0);
    let right_wrap_start = out_w - hi.max(0);

    // Step 1: fold `pos` onto the already-written center band of the output.
    let center_pos = if pos < left_wrap {
        // out[0..lo] <- out[out_w - lo - max(hi,0) .. out_w - max(hi,0)]
        right_wrap_start - lo + pos
    } else if pos >= right_wrap_start {
        // out[out_w - hi .. out_w] <- out[max(lo,0) .. max(lo,0) + hi]
        left_wrap + pos - (out_w - hi)
    } else {
        pos
    };

    // Step 2: center output index -> input index (skip the cropped prefix).
    let crop_left = -(lo.min(0));
    crop_left + center_pos - left_wrap
}
1637
1638/// PER-AXIS circular-pad legality, returning the new axis extent
1639/// (`size + lo + hi`, which may be `0` for a net-zero crop → an empty dim).
1640///
1641/// This mirrors EXACTLY the two `TORCH_CHECK`s inside `_pad_circular_symint`'s
1642/// shape loop (`aten/src/ATen/native/PadNd.cpp:140-145`) — and ONLY those. The
1643/// center slice-copy (`:158-161`) and the wrap gather (`:169-187`) are NOT
1644/// per-axis legality: torch first allocates the FULL N-D output
1645/// (`:148 auto out = self.new_empty_symint(out_shape, ...)`) and only THEN does
1646/// the per-axis `copy_`. Those `copy_`s are validated in
1647/// [`circular_axis_validate_nonempty`], gated on the WHOLE output being
1648/// non-empty (when any axis is `0`, `out` has `numel 0` and every `copy_` is a
1649/// no-op — see the holistic restructure in [`pad_nd_signed_reflect_circular`]).
1650///
1651/// - `:140-142` `TORCH_CHECK(pad_l <= size && pad_r <= size, "Padding value
1652///   causes wrapping around more than once.")` — a pad strictly greater than
1653///   `size` wraps more than once → `Err`. This is the ONLY per-axis legality.
1654/// - `:143-145` `TORCH_CHECK(out_shape >= 0, "Negative padding value is
1655///   resulting in an empty dimension")` — a negative net extent → `Err`; a net
1656///   extent of EXACTLY `0` is allowed (an empty `[..,0]` dim, like
1657///   `constant_pad_nd`), distinct from reflect which demands `>= 1`.
1658fn circular_axis_legality(
1659    size: usize,
1660    lo: isize,
1661    hi: isize,
1662    dim: usize,
1663) -> FerrotorchResult<usize> {
1664    let size_i = size as isize;
1665    // `:140-142` — a pad larger than the dim wraps around more than once.
1666    if lo > size_i || hi > size_i {
1667        return Err(FerrotorchError::InvalidArgument {
1668            message: format!(
1669                "Circular padding ({lo}, {hi}) causes wrapping around more than once on dimension {dim} (size {size})"
1670            ),
1671        });
1672    }
1673    // `:143-145` — a negative net extent is an error; net zero is an empty dim.
1674    let out_w = size_i + lo + hi;
1675    if out_w < 0 {
1676        return Err(FerrotorchError::InvalidArgument {
1677            message: format!(
1678                "Circular padding ({lo}, {hi}) on dimension {dim} of size {size} results in a negative output size {out_w} (empty dimension)"
1679            ),
1680        });
1681    }
1682    Ok(out_w as usize)
1683}
1684
/// Turn a `tensor.slice(dim, start, end)` request over a `length`-element axis
/// into a clamped half-open `[start, end)` range, following torch's
/// `slice_symint` normalization:
/// - a negative index has `length` added to it;
/// - both ends are then clamped into `[0, length]`;
/// - `end` is never allowed to fall below `start`.
///
/// Used by [`circular_slicecopy_block`] to model every `slice_symint` in
/// `_pad_circular_symint` (`PadNd.cpp:148-187`).
#[inline]
fn circular_slice_range(length: isize, start: isize, end: isize) -> (usize, usize) {
    let normalize = |idx: isize| -> isize {
        let shifted = if idx < 0 { idx + length } else { idx };
        shifted.clamp(0, length)
    };
    let first = normalize(start);
    let last = normalize(end).max(first);
    (first as usize, last as usize)
}
1705
/// HOLISTIC faithful simulation of torch's `_pad_circular_symint` allocate-then-
/// copy algorithm (`aten/src/ATen/native/PadNd.cpp:148-187`) over the last
/// `npad` dims of ONE outer batch block.
///
/// This replaces the prior per-axis wrap-OOB / center-copy pre-validation. That
/// check rejected an axis whose ISOLATED wrap was OOB even when a SIBLING axis
/// had already emptied the whole output (the #1628 cross-axis net-zero
/// divergence). Instead of validating each axis in isolation, this function
/// reproduces torch's exact sequence on the full N-D output buffer:
///
/// - `:148` `auto out = self.new_empty_symint(out_shape)`: a buffer with an
///   `init` mask (all `false`). Uninitialized cells are tracked so that an
///   over-crop leaving a final cell unwritten is detected as the R-DEV-6
///   carve-out.
/// - `:154-161` ONE center `copy_`. Narrow `out` and `self` on every padded dim
///   by `slice(dim, max(pad,0), …)` / `slice(dim, max(-pad,0), …)`, then copy.
///   `copy_` errors unless the source broadcasts to the destination shape (per
///   dim: sizes equal OR source size 1); a mismatch is a torch `RuntimeError`.
/// - `:169-187` the left/right wrap `copy_`s, each reading from `out` LIVE.
///   `in_slice = out.slice_symint(...)` aliases the buffer being written, so a
///   wrap reads cells the center or an earlier wrap just wrote (`:163-165`:
///   "Corners will be written more than once"). The same broadcast-legality
///   gate applies.
///
/// Because the wraps read from `out`, an axis whose isolated wrap would be OOB
/// is harmless when a different axis emptied the output: every `copy_` is then
/// a no-op over the empty extent. Torch's well-defined cross-axis wraps, which
/// the prior per-axis check rejected, are now reproduced byte-for-byte.
///
/// After all copies, any cell still uninitialized means torch read
/// uninitialized memory there, so there is no reproducible byte-for-byte
/// contract (R-DEV-6). ferrotorch rejects such cases cleanly rather than
/// returning nondeterministic garbage (R-CODE-2: no panic). The legality
/// checks `:140-145` are already enforced by [`circular_axis_legality`] before
/// this runs.
///
/// # Parameters
///
/// - `in_block`: row-major data of ONE outer batch block, indexed with
///   strides derived from `in_inner_shape`. Assumes
///   `in_block.len() == in_inner_shape.iter().product()` (TODO confirm at the
///   caller; not checked here).
/// - `in_inner_shape` / `out_inner_shape`: the per-block input and output
///   shapes. Both are indexed by the same dim index `0..in_inner_shape.len()`,
///   so `out_inner_shape` is expected to have the same rank (NOTE(review): not
///   checked here).
/// - `pads`: one signed `(lo, hi)` pair per padded dim, ordered LAST padded
///   axis first. The padded dims are the trailing `pads.len()` inner dims.
///   `ninner - npad` is computed with unchecked subtraction, so
///   `pads.len() <= in_inner_shape.len()` is assumed.
///
/// # Returns
///
/// The output block, row-major over `out_inner_shape`, with every cell
/// written (the trailing init-mask check guarantees this on `Ok`).
///
/// # Errors
///
/// `FerrotorchError::InvalidArgument` when:
/// - a slice `copy_` is not broadcastable (per dim, the source extent must
///   equal the destination extent or be `1`);
/// - a live-read wrap would overlap itself over a contiguous memory run
///   (torch's `copy_` raises there);
/// - any output cell is still uninitialized after all copies (R-DEV-6).
fn circular_slicecopy_block<T: Float>(
    in_block: &[T],
    in_inner_shape: &[usize],
    out_inner_shape: &[usize],
    pads: &[(isize, isize)],
) -> FerrotorchResult<Vec<T>> {
    let npad = pads.len();
    let ninner = in_inner_shape.len();
    let out_total: usize = out_inner_shape.iter().product();
    let zero = <T as num_traits::Zero>::zero();
    // `out` holds values; `init[i]` records whether `out[i]` has been written
    // from an initialized source (models torch's `new_empty` garbage).
    let mut out = vec![zero; out_total];
    let mut init = vec![false; out_total];

    // Row-major strides for the inner (padded-region) coordinate space.
    let mut in_strides = vec![1usize; ninner];
    let mut out_strides = vec![1usize; ninner];
    for d in (0..ninner.saturating_sub(1)).rev() {
        in_strides[d] = in_strides[d + 1] * in_inner_shape[d + 1];
        out_strides[d] = out_strides[d + 1] * out_inner_shape[d + 1];
    }

    // `pads` is ordered LAST padded axis first; the padded inner dims are the
    // trailing `npad` dims of the inner block. Inner-dim index for pad entry `k`
    // (which targets axis `dim = ninner - 1 - k`).
    let pad_for_inner_dim = |d: usize| -> (isize, isize) {
        // d in [ninner-npad, ninner-1] -> k = ninner - 1 - d
        pads[ninner - 1 - d]
    };

    // Per-dim half-open copy windows for the dst (`out`) and src.
    // `copy_block` copies `src[src_win]` (broadcast) into `out[dst_win]`,
    // propagating the init mask, and returns Err on a broadcast-illegal `copy_`.
    // `dst_win`/`src_win` are `(start,end)` per inner dim.
    //
    // When `read_data` is `Some`, the source is a SEPARATE buffer (the original
    // input, for the center copy — `read_strides` indexes it). When `read_data`
    // is `None`, the source is `out` ITSELF, read LIVE in the same pass: this
    // mirrors torch's `:169-187` wrap `copy_`s where `in_slice = out.slice(...)`
    // aliases the very `out` buffer being written (`read_strides` indexes `out`).
    // Iterating in row-major dst order, a wrap cell can therefore read a cell the
    // center (or an earlier dst cell in this same wrap) just wrote, deterministi-
    // cally propagating a narrow center band exactly as torch does (#1629).
    //
    // The clippy allow is needed because a nested `fn` cannot capture the
    // enclosing locals, so the buffers, strides and windows are all passed
    // explicitly.
    #[allow(clippy::too_many_arguments)]
    fn copy_block<T: Float>(
        out: &mut [T],
        init: &mut [bool],
        read_data: Option<&[T]>,
        read_init: Option<&[bool]>,
        ninner: usize,
        out_strides: &[usize],
        read_strides: &[usize],
        dst_win: &[(usize, usize)],
        src_win: &[(usize, usize)],
    ) -> FerrotorchResult<()> {
        // Broadcast-legality (torch `copy_`): per dim, dst extent must equal src
        // extent OR src extent must be 1. Otherwise torch raises. This runs
        // BEFORE the empty-extent early return, so a mismatch is reported even
        // when the destination is empty.
        let mut dst_ext = vec![0usize; ninner];
        let mut src_ext = vec![0usize; ninner];
        for d in 0..ninner {
            dst_ext[d] = dst_win[d].1 - dst_win[d].0;
            src_ext[d] = src_win[d].1 - src_win[d].0;
            if dst_ext[d] != src_ext[d] && src_ext[d] != 1 {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "Circular padding: a slice copy of source extent {} into destination extent {} is not broadcastable on inner dim {d} (torch raises a size-mismatch here)",
                        src_ext[d], dst_ext[d]
                    ),
                });
            }
        }
        let total: usize = dst_ext.iter().product();
        if total == 0 {
            return Ok(()); // no-op over an empty extent (`:148` empty out_shape)
        }
        // torch `copy_` memory-overlap gate (live-read wraps only). When the wrap
        // reads from `out` itself (`read_data is None`), torch's `copy_` raises
        // `RuntimeError: ... refer to a single memory location` when the source
        // and destination slices each form a CONTIGUOUS memory run AND those runs
        // overlap by a non-identity offset (MEM_OVERLAP_YES). A wrap slices a
        // SINGLE dim `wd` (all other dims full-extent); its dst/src each form a
        // contiguous run iff every dim MORE MAJOR than `wd` has extent 1 (else the
        // slice repeats once per major index → strided, and torch's overlap
        // detector returns "too hard" and proceeds with the well-defined band
        // propagation, #1629). An EXACT-identity window pair is a self-copy no-op
        // torch always allows; disjoint windows never overlap. We mirror torch's
        // raise as a clean `Err` (R-CODE-2: never a panic).
        if read_data.is_none() {
            let mut wrap_dim: Option<usize> = None;
            for d in 0..ninner {
                if dst_win[d] != src_win[d] {
                    // a wrap differs on exactly one (the wrap) dim
                    wrap_dim = Some(d);
                    break;
                }
            }
            if let Some(wd) = wrap_dim {
                // contiguous run iff every more-major dim is collapsed to extent 1
                let runs_contiguous = (0..wd).all(|d| dst_ext[d] == 1);
                let ds = dst_win[wd];
                let ss = src_win[wd];
                let overlap = ds.0 < ss.1 && ss.0 < ds.1; // half-open range overlap
                let identical = ds == ss;
                if runs_contiguous && overlap && !identical {
                    return Err(FerrotorchError::InvalidArgument {
                        message:
                            "Circular padding: torch's wrap copy_ would read and write a single memory location over a contiguous slice (RuntimeError: some elements of the input and written-to tensor refer to a single memory location); ferrotorch rejects rather than fabricate (R-DEV-6)"
                                .to_string(),
                    });
                }
            }
        }
        // Iterate every dst coordinate; map to the (broadcast) src coordinate.
        let mut coord = vec![0usize; ninner];
        for _ in 0..total {
            let mut dst_off = 0usize;
            let mut src_off = 0usize;
            for d in 0..ninner {
                let dc = dst_win[d].0 + coord[d];
                dst_off += dc * out_strides[d];
                // A size-1 source dim is broadcast: always read its single index.
                let sc = if src_ext[d] == 1 {
                    src_win[d].0
                } else {
                    src_win[d].0 + coord[d]
                };
                src_off += sc * read_strides[d];
            }
            // LIVE read: when `read_data` is `None` the source IS `out`/`init`
            // (torch's wrap `in_slice = out.slice(...)`), so we read the current
            // value at `src_off` — including a cell written earlier in this very
            // pass — before overwriting `dst_off`. A separate source buffer with
            // no `read_init` mask is treated as fully initialized.
            let (v, src_inited) = match (read_data, read_init) {
                (Some(rd), ri) => (rd[src_off], ri.map(|m| m[src_off]).unwrap_or(true)),
                (None, _) => (out[src_off], init[src_off]),
            };
            out[dst_off] = v;
            init[dst_off] = src_inited;
            // advance coord (row-major over dst extents)
            let mut d = ninner;
            while d > 0 {
                d -= 1;
                coord[d] += 1;
                if coord[d] < dst_ext[d] {
                    break;
                }
                coord[d] = 0;
            }
        }
        Ok(())
    }

    // `:154-161` — the single CENTER copy. Build dst/src windows per inner dim.
    let mut dst_win = vec![(0usize, 0usize); ninner];
    let mut src_win = vec![(0usize, 0usize); ninner];
    for d in 0..ninner {
        let out_len = out_inner_shape[d] as isize;
        let in_len = in_inner_shape[d] as isize;
        if d < ninner - npad {
            // Non-padded inner dim: full extent on both sides.
            dst_win[d] = (0, out_inner_shape[d]);
            src_win[d] = (0, in_inner_shape[d]);
        } else {
            let (pl, pr) = pad_for_inner_dim(d);
            // out.slice(dim, max(pl,0), out_len - max(pr,0)) <-
            // self.slice(dim, max(-pl,0), in_len - max(-pr,0))
            dst_win[d] = circular_slice_range(out_len, pl.max(0), out_len - pr.max(0));
            src_win[d] = circular_slice_range(in_len, (-pl).max(0), in_len - (-pr).max(0));
        }
    }
    copy_block(
        &mut out,
        &mut init,
        Some(in_block),
        None,
        ninner,
        &out_strides,
        &in_strides,
        &dst_win,
        &src_win,
    )?;

    // `:169-187` — the left/right wrap copies, each reading from `out` LIVE.
    // torch's `in_slice = out.slice_symint(...)` (`:176`/`:184`) aliases the SAME
    // `out` buffer the loop is writing, and `:163-165` is explicit that corners
    // are written more than once across the sequence. So each wrap reads the
    // CURRENT `out` (including cells the center or an earlier wrap just wrote),
    // deterministically propagating a narrow over-cropped center band exactly as
    // torch does (#1629). We pass `read_data = None` so `copy_block` reads `out`/
    // `init` in place — NOT a pre-copy snapshot. Cells torch never writes stay
    // uninit and are caught by the leftover-uninit R-DEV-6 check below.
    //
    // Keep this visiting order (last padded axis first; left wrap before right
    // wrap per axis) in sync with `circular_slicecopy_backward_block`, which
    // records the same ops in the same order and transposes them in reverse.
    for (k, &(pl, pr)) in pads.iter().enumerate() {
        // i in torch is k counted from the FIRST padded axis; torch's `dim` is
        // the inner dim `ninner - npad + k`. Our `pads` is last-axis-first, so
        // entry k targets inner dim `ninner - 1 - k`. torch iterates i=0..npad
        // over `pad[2*i]` (first-axis-first); the set of (dim,pl,pr) visited is
        // identical, only the order differs — and torch's wraps on distinct dims
        // are order-independent for the WELL-DEFINED cases (the order-dependent
        // overlapping ones land in the R-DEV-6 leftover-uninit reject either way).
        let dim = ninner - 1 - k;
        let out_len = out_inner_shape[dim] as isize;
        if pl > 0 {
            // Left wrap: out[.., 0..pl] <- out[.., out_len-pl-max(pr,0) .. out_len-max(pr,0)]
            let mut dwin = vec![(0usize, 0usize); ninner];
            let mut swin = vec![(0usize, 0usize); ninner];
            for d in 0..ninner {
                dwin[d] = (0, out_inner_shape[d]);
                swin[d] = (0, out_inner_shape[d]);
            }
            dwin[dim] = circular_slice_range(out_len, 0, pl);
            swin[dim] =
                circular_slice_range(out_len, out_len - pl - pr.max(0), out_len - pr.max(0));
            copy_block(
                &mut out,
                &mut init,
                None,
                None,
                ninner,
                &out_strides,
                &out_strides,
                &dwin,
                &swin,
            )?;
        }
        if pr > 0 {
            // Right wrap: out[.., out_len-pr .. out_len] <- out[.., max(pl,0) .. max(pl,0)+pr]
            let mut dwin = vec![(0usize, 0usize); ninner];
            let mut swin = vec![(0usize, 0usize); ninner];
            for d in 0..ninner {
                dwin[d] = (0, out_inner_shape[d]);
                swin[d] = (0, out_inner_shape[d]);
            }
            dwin[dim] = circular_slice_range(out_len, out_len - pr, out_len);
            swin[dim] = circular_slice_range(out_len, pl.max(0), pl.max(0) + pr);
            copy_block(
                &mut out,
                &mut init,
                None,
                None,
                ninner,
                &out_strides,
                &out_strides,
                &dwin,
                &swin,
            )?;
        }
    }

    // R-DEV-6: if any output cell is still uninitialized, torch read freed /
    // uninitialized memory there (a mixed-sign over-crop where the cropped
    // center is narrower than the wrap, or an overlapping `copy_`). There is no
    // reproducible byte-for-byte contract, so ferrotorch rejects cleanly rather
    // than emit nondeterministic garbage (R-CODE-2: no panic).
    if init.iter().any(|&b| !b) {
        return Err(FerrotorchError::InvalidArgument {
            message:
                "Circular padding crops the center below the wrap width, so torch reads uninitialized memory (no byte-for-byte contract; R-DEV-6)"
                    .to_string(),
        });
    }
    Ok(out)
}
1991
/// A single recorded `copy_` from torch's circular-pad forward sequence
/// (`aten/src/ATen/native/PadNd.cpp:148-187`), kept so the backward pass can
/// transpose the sequence op by op.
struct CircularCopyOp {
    /// Every `(dst_offset, src_offset)` cell pair the copy touches, as flat
    /// offsets into the inner (padded-region) buffer. The backward block records
    /// them in row-major destination order.
    pairs: Vec<(usize, usize)>,
    /// Selects the gradient buffer the source side backprops into:
    /// - `true` for the center copy, whose source is the ORIGINAL input buffer
    ///   (`PadNd.cpp:154-161`);
    /// - `false` for a wrap copy, whose source is the LIVE `out` buffer
    ///   (`:169-187`).
    from_input: bool,
}
2002
/// BACKWARD of torch's circular slice-copy forward over ONE outer batch block,
/// computed as the autograd TRANSPOSE of the forward `copy_` sequence.
///
/// `_pad_circular` (`PadNd.cpp:148-187`) is a differentiable composition of
/// `new_empty` + `slice` + `copy_`, so torch autograd differentiates it
/// directly; there is no hand-written backward. Each
/// `out_slice.copy_(in_slice)` (`:161`, `:179`, `:185`) backprops as
/// `grad_src += grad_dst` over its copied cells and then ZEROS `grad_dst`. A
/// `copy_` OVERWRITES the destination, so the destination's pre-copy value did
/// not flow forward.
///
/// Processing the recorded ops in REVERSE order with this accumulate-then-zero
/// rule reproduces torch's grad byte-for-byte (R-DEV-1). That includes the
/// over-crop wrap-propagation cases (#1629/#1631), where the OLD per-axis
/// `circular_axis_src` gather returned an out-of-range source index and
/// PANICKED (#1631, R-CODE-2).
///
/// A center copy's source is the ORIGINAL input, so its grad accumulates into
/// `grad_in`. A wrap's source is a LIVE `out` cell, so its grad accumulates
/// back into the working `grad_out`. For an empty / net-zero output every op
/// has zero pairs, giving zero grad contribution (no OOB).
///
/// # Parameters
///
/// - `go_block`: this block's output gradient, row-major over
///   `out_inner_shape`. Assumes `go_block.len() == out_inner_shape.iter().product()`
///   (TODO confirm at the caller; not checked here).
/// - `in_inner_shape` / `out_inner_shape` / `pads`: same meaning and layout
///   as for the forward `circular_slicecopy_block`. `pads` is ordered last
///   padded axis first, and `pads.len() <= in_inner_shape.len()` is assumed
///   (the `ninner - npad` subtraction is unchecked).
///
/// # Returns
///
/// This block's input gradient, `in_inner_shape.iter().product()` elements,
/// row-major over `in_inner_shape`.
fn circular_slicecopy_backward_block<T: Float>(
    go_block: &[T],
    in_inner_shape: &[usize],
    out_inner_shape: &[usize],
    pads: &[(isize, isize)],
) -> Vec<T> {
    let npad = pads.len();
    let ninner = in_inner_shape.len();
    let in_total: usize = in_inner_shape.iter().product();

    // Row-major strides, identical to the forward's.
    let mut in_strides = vec![1usize; ninner];
    let mut out_strides = vec![1usize; ninner];
    for d in (0..ninner.saturating_sub(1)).rev() {
        in_strides[d] = in_strides[d + 1] * in_inner_shape[d + 1];
        out_strides[d] = out_strides[d + 1] * out_inner_shape[d + 1];
    }

    // Inner dim `d` (one of the trailing `npad` dims) -> its `(lo, hi)` pad.
    let pad_for_inner_dim = |d: usize| -> (isize, isize) { pads[ninner - 1 - d] };

    // Enumerate the `(dst_off, src_off)` pairs of one `copy_`, mirroring
    // `circular_slicecopy_block`'s `copy_block` iteration 1:1 (same windows,
    // same row-major dst order, same broadcast rule). `src_strides` indexes the
    // input for the center copy and `out` for a live wrap. No broadcast-legality
    // check is repeated here; presumably the forward already rejected illegal
    // windows for the same shapes/pads — confirm the caller only runs this
    // after a successful forward.
    let enum_pairs = |dst_win: &[(usize, usize)],
                      src_win: &[(usize, usize)],
                      src_strides: &[usize]|
     -> Vec<(usize, usize)> {
        let mut dst_ext = vec![0usize; ninner];
        let mut src_ext = vec![0usize; ninner];
        for d in 0..ninner {
            dst_ext[d] = dst_win[d].1 - dst_win[d].0;
            src_ext[d] = src_win[d].1 - src_win[d].0;
        }
        let total: usize = dst_ext.iter().product();
        let mut pairs = Vec::with_capacity(total);
        if total == 0 {
            return pairs;
        }
        let mut coord = vec![0usize; ninner];
        for _ in 0..total {
            let mut dst_off = 0usize;
            let mut src_off = 0usize;
            for d in 0..ninner {
                dst_off += (dst_win[d].0 + coord[d]) * out_strides[d];
                // Broadcast source dim (extent 1): always its single index.
                let sc = if src_ext[d] == 1 {
                    src_win[d].0
                } else {
                    src_win[d].0 + coord[d]
                };
                src_off += sc * src_strides[d];
            }
            pairs.push((dst_off, src_off));
            // advance coord (row-major over dst extents)
            let mut d = ninner;
            while d > 0 {
                d -= 1;
                coord[d] += 1;
                if coord[d] < dst_ext[d] {
                    break;
                }
                coord[d] = 0;
            }
        }
        pairs
    };

    let mut ops: Vec<CircularCopyOp> = Vec::new();

    // `:154-161` — the single CENTER copy (source = original input window).
    let mut dst_win = vec![(0usize, 0usize); ninner];
    let mut src_win = vec![(0usize, 0usize); ninner];
    for d in 0..ninner {
        let out_len = out_inner_shape[d] as isize;
        let in_len = in_inner_shape[d] as isize;
        if d < ninner - npad {
            dst_win[d] = (0, out_inner_shape[d]);
            src_win[d] = (0, in_inner_shape[d]);
        } else {
            let (pl, pr) = pad_for_inner_dim(d);
            dst_win[d] = circular_slice_range(out_len, pl.max(0), out_len - pr.max(0));
            src_win[d] = circular_slice_range(in_len, (-pl).max(0), in_len - (-pr).max(0));
        }
    }
    ops.push(CircularCopyOp {
        from_input: true,
        pairs: enum_pairs(&dst_win, &src_win, &in_strides),
    });

    // `:169-187` — the left/right wrap copies (source = LIVE `out`), recorded in
    // the SAME order as the forward.
    for (k, &(pl, pr)) in pads.iter().enumerate() {
        let dim = ninner - 1 - k;
        let out_len = out_inner_shape[dim] as isize;
        if pl > 0 {
            let mut dwin = vec![(0usize, 0usize); ninner];
            let mut swin = vec![(0usize, 0usize); ninner];
            for d in 0..ninner {
                dwin[d] = (0, out_inner_shape[d]);
                swin[d] = (0, out_inner_shape[d]);
            }
            dwin[dim] = circular_slice_range(out_len, 0, pl);
            swin[dim] =
                circular_slice_range(out_len, out_len - pl - pr.max(0), out_len - pr.max(0));
            ops.push(CircularCopyOp {
                from_input: false,
                pairs: enum_pairs(&dwin, &swin, &out_strides),
            });
        }
        if pr > 0 {
            let mut dwin = vec![(0usize, 0usize); ninner];
            let mut swin = vec![(0usize, 0usize); ninner];
            for d in 0..ninner {
                dwin[d] = (0, out_inner_shape[d]);
                swin[d] = (0, out_inner_shape[d]);
            }
            dwin[dim] = circular_slice_range(out_len, out_len - pr, out_len);
            swin[dim] = circular_slice_range(out_len, pl.max(0), pl.max(0) + pr);
            ops.push(CircularCopyOp {
                from_input: false,
                pairs: enum_pairs(&dwin, &swin, &out_strides),
            });
        }
    }

    // Reverse transpose: working `grad_out` starts as the incoming output grad;
    // `grad_in` accumulates the input grad. For each op (reverse order) add
    // `grad_out[dst]` into its source, THEN zero `grad_out[dst]` (the `copy_`
    // overwrote `dst`, so its pre-copy value contributed nothing). A wrap's
    // source is a live `out` cell ⇒ accumulate back into `grad_out`; the center
    // copy's source is an input cell ⇒ accumulate into `grad_in`. We accumulate
    // every contribution for the op BEFORE zeroing so distinct dst cells reading
    // distinct (or broadcast-shared) sources all land correctly.
    let zero = <T as num_traits::Zero>::zero();
    let mut grad_out = go_block.to_vec();
    let mut grad_in = vec![zero; in_total];
    for op in ops.iter().rev() {
        if op.from_input {
            // Source lives in a separate buffer (the input), so per-pair
            // accumulate-then-zero cannot interfere with other pairs' sources.
            for &(d, s) in &op.pairs {
                grad_in[s] += grad_out[d];
                grad_out[d] = zero;
            }
        } else {
            // Simultaneous transpose of one wrap `copy_`:
            // 1. snapshot every `grad_out[dst]`;
            // 2. zero all dst cells;
            // 3. scatter the snapshots into their source cells.
            // Every dst's grad is read BEFORE any dst is zeroed or any source is
            // bumped, i.e. the op is transposed as if it read pre-copy values.
            //
            // NOTE(review): dst and src windows are NOT always disjoint here. The
            // forward overlap gate in `circular_slicecopy_block` only rejects
            // overlapping windows that form CONTIGUOUS runs. Strided overlapping
            // wraps (the #1629 band-propagation case) are accepted, and the
            // forward reads `out` LIVE, so a pair's source can be a dst cell
            // written earlier in the same op.
            //
            // This simultaneous scatter does not follow such a chain: gradient
            // reaching a chained cell stays there instead of flowing on to the
            // cell it was copied from. That is presumably deliberate parity with
            // torch autograd's treatment of `copy_` on an aliasing view (grad
            // computed against the pre-copy graph, not live memory). Confirm
            // against torch before relying on it.
            //
            // An exact transpose of the live-read forward would instead walk
            // `op.pairs` in reverse, doing
            // `v = grad_out[d]; grad_out[d] = 0; grad_out[s] += v` per pair.
            let mut contrib: Vec<(usize, T)> = Vec::with_capacity(op.pairs.len());
            for &(d, s) in &op.pairs {
                contrib.push((s, grad_out[d]));
            }
            for &(d, _) in &op.pairs {
                grad_out[d] = zero;
            }
            for (s, v) in contrib {
                grad_out[s] += v;
            }
        }
    }
    grad_in
}
2181
2182/// Resolve, for one axis, the source index a reflect/circular output index
2183/// reads from the ORIGINAL input window. Both modes always read a real element
2184/// (never a fill), so this returns a bare `usize`. `(lo, hi)` are the signed
2185/// pads on this axis. The circular index is pre-validated in
2186/// `circular_axis_new_size` to lie in `0..size`, so the `as usize` cast here is
2187/// always in-bounds (no OOB — R-CODE-2).
2188#[inline]
2189fn signed_mode_axis_src(mode: PaddingMode, j: usize, size: usize, lo: isize, hi: isize) -> usize {
2190    match mode {
2191        PaddingMode::Reflect => reflect_axis_src(j, size, lo),
2192        PaddingMode::Replicate => replicate_axis_src(j, size, lo),
2193        PaddingMode::Circular => circular_axis_src(j, size, lo, hi) as usize,
2194        // Zeros routes through the constant gather; this resolver is only invoked
2195        // for Reflect/Replicate/Circular (see `pad_nd_signed_reflect_circular` /
2196        // `PadNdSignedModeBackward`); the clamp here is a defensive in-bounds
2197        // fallback that never executes.
2198        PaddingMode::Zeros => (j as isize - lo).clamp(0, size as isize - 1) as usize,
2199    }
2200}
2201
2202/// Crop-capable reflect/replicate/circular pad over the last `npad` dimensions
2203/// using the unified index map against the ORIGINAL input window. `pads` is
2204/// `[(lo,hi), ...]` ordered LAST padded axis first. Output extent per axis is
2205/// `size + lo + hi` (negative pads narrow). Reflect legality (SIGNED `lo < size`
2206/// and `hi < size` per axis, checked against the ORIGINAL size, mirroring
2207/// `aten/src/ATen/native/ReflectionPad.cpp:48-49`) is validated here. Reflect &
2208/// replicate use a RANK-DEPENDENT net-zero rule: 1-D requires output `>= 1`
2209/// while 2-D/3-D allow a per-axis net-zero (empty `[..,0,..]`) so long as one
2210/// padded axis survives (`ReflectionPad.cpp:251`/`:152`,
2211/// `ReplicationPadding.cpp:114`). Replicate gathers with the boundary clamp of
2212/// `ReplicationPad::index` (`cpu/PaddingKernel.cpp:84-95`), so an over-crop to a
2213/// zero-size axis never underflows (#1625).
2214fn pad_nd_signed_reflect_circular<T: Float>(
2215    data: &[T],
2216    shape: &[usize],
2217    pads: &[(isize, isize)],
2218    mode: PaddingMode,
2219) -> FerrotorchResult<(Vec<T>, Vec<usize>)> {
2220    let ndim = shape.len();
2221    let npad = pads.len();
2222    let mut new_shape = shape.to_vec();
2223    // Reflect's net-zero output rule is RANK-DEPENDENT (matches torch's per-rank
2224    // meta functions): 1-D `reflection_pad1d` requires `output_w >= 1`
2225    // (`aten/src/ATen/native/ReflectionPad.cpp:60-65`) so a net-zero axis Errs,
2226    // but 2-D `reflection_pad2d` (`:251`) and 3-D `reflection_pad3d` (`:152`)
2227    // require only `output_w >= 1 || output_h >= 1 (|| output_d >= 1)`, allowing
2228    // an INDIVIDUAL spatial axis to be net-zero (an empty `[..,0,..]` tensor) as
2229    // long as at least one spatial axis survives. Replicate has the identical
2230    // rank-dependent shape: `replication_pad1d` requires `owidth >= 1`
2231    // (`ReplicationPadding.cpp:49`) while `replication_pad2d`/`3d` use the same
2232    // OR (`:114`). So per-axis we reject a net-zero ONLY when `npad == 1` (the
2233    // 1-D kernel); for `npad >= 2` a single axis may be 0, and a final guard
2234    // below enforces that not ALL spatial axes are 0. (#1626)
2235    let per_axis_min: isize = isize::from(npad == 1);
2236    for (k, &(lo, hi)) in pads.iter().enumerate() {
2237        let dim = ndim - 1 - k;
2238        let size = shape[dim] as isize;
2239        // Reflect: torch's check is SIGNED, not absolute
2240        // (`aten/src/ATen/native/ReflectionPad.cpp:48-49`):
2241        // `TORCH_CHECK(pad_l < input_w && pad_r < input_w, ...)`. A NEGATIVE
2242        // (crop) pad is always `< input_w`, so torch only rejects POSITIVE pads
2243        // whose magnitude reaches `>= input_w`. Replicate has NO such
2244        // `pad < input` check upstream (`ReplicationPadding.cpp` only guards the
2245        // output extent), so this rejection is reflect-only.
2246        if mode == PaddingMode::Reflect && (lo >= size || hi >= size) {
2247            return Err(FerrotorchError::InvalidArgument {
2248                message: format!(
2249                    "Reflection padding ({lo}, {hi}) must be less than input size ({size}) on dimension {dim}"
2250                ),
2251            });
2252        }
2253        // Replicate requires a non-empty ORIGINAL axis (the clamp gathers a real
2254        // boundary element). torch's `check_valid_input` rejects a zero-size
2255        // input plane, so size 0 here is impossible for a valid call; guard
2256        // defensively to keep the clamp index in `0..size`.
2257        if mode == PaddingMode::Replicate && size == 0 {
2258            return Err(FerrotorchError::InvalidArgument {
2259                message: format!(
2260                    "Replication padding cannot replicate an empty input dimension {dim} (size 0)"
2261                ),
2262            });
2263        }
2264        // Circular: torch's `_pad_circular_symint` is allocate-then-copy
2265        // (`aten/src/ATen/native/PadNd.cpp:140-187`). The PER-AXIS legality is
2266        // ONLY `:142` (reject `pad > size`, wraps more than once) and `:144`
2267        // (reject a negative net extent; allow exactly `0` → an empty dim) —
2268        // `circular_axis_legality`. The center copy (`:158-161`) and the wrap
2269        // gather (`:169-187`) operate on slices of the FULL `:148 new_empty`
2270        // output, so they are validated SEPARATELY below, gated on the WHOLE
2271        // output being non-empty (any `out_i == 0` ⇒ every `copy_` no-ops ⇒
2272        // torch returns the empty tensor without materializing ANY wrap index,
2273        // #1628). Reflect/Replicate use the rank-dependent `per_axis_min` reject.
2274        let new_size: usize = if mode == PaddingMode::Circular {
2275            circular_axis_legality(shape[dim], lo, hi, dim)?
2276        } else {
2277            let n = size + lo + hi;
2278            if n < per_axis_min {
2279                return Err(FerrotorchError::InvalidArgument {
2280                    message: format!(
2281                        "padding ({lo}, {hi}) on dimension {dim} of size {size} yields output size {n} below the minimum {per_axis_min} for this rank"
2282                    ),
2283                });
2284            }
2285            n as usize
2286        };
2287        new_shape[dim] = new_size;
2288    }
2289
2290    // 2-D/3-D reflect & replicate: at least one padded spatial axis must survive
2291    // (`output_w >= 1 || output_h >= 1 (|| output_d >= 1)`,
2292    // `ReflectionPad.cpp:251`/`:152`, `ReplicationPadding.cpp:114`). When every
2293    // padded axis collapsed to 0, torch Errs "input is too small".
2294    if npad >= 2
2295        && matches!(mode, PaddingMode::Reflect | PaddingMode::Replicate)
2296        && pads
2297            .iter()
2298            .enumerate()
2299            .all(|(k, _)| new_shape[ndim - 1 - k] == 0)
2300    {
2301        return Err(FerrotorchError::InvalidArgument {
2302            message: format!(
2303                "{mode:?} padding collapses every padded spatial axis to size 0 (torch requires at least one >= 1)"
2304            ),
2305        });
2306    }
2307
2308    let first_padded = ndim - npad;
2309    let outer: usize = shape[..first_padded]
2310        .iter()
2311        .copied()
2312        .product::<usize>()
2313        .max(1);
2314    let in_inner: usize = shape[first_padded..].iter().product();
2315    let out_inner: usize = new_shape[first_padded..].iter().product();
2316    let zero = <T as num_traits::Zero>::zero();
2317    let new_total: usize = new_shape.iter().copied().product();
2318    let mut out = vec![zero; new_total];
2319
2320    // CIRCULAR: HOLISTIC allocate-then-copy (`PadNd.cpp:148-187`) per outer
2321    // batch, mirroring torch's `:148 new_empty(out_shape)` + center/wrap `copy_`
2322    // sequence on the FULL N-D output. This replaces the prior per-axis wrap-OOB
2323    // pre-validation + per-axis gather, which rejected an axis whose ISOLATED
2324    // wrap was OOB even when a SIBLING axis had already emptied the whole output
2325    // (the #1628 cross-axis net-zero divergence). The simulator reproduces the
2326    // empty short-circuit (any `out_i == 0` ⇒ every `copy_` no-ops ⇒ empty
2327    // tensor), the cross-axis well-defined wraps, AND the R-DEV-6 over-crop
2328    // rejection (leftover-uninit ⇒ Err, never a panic) in one faithful pass.
2329    if mode == PaddingMode::Circular {
2330        let in_inner_shape = &shape[first_padded..];
2331        let out_inner_shape = &new_shape[first_padded..];
2332        for o in 0..outer {
2333            let in_block = &data[o * in_inner..(o + 1) * in_inner];
2334            let out_block =
2335                circular_slicecopy_block(in_block, in_inner_shape, out_inner_shape, pads)?;
2336            out[o * out_inner..(o + 1) * out_inner].copy_from_slice(&out_block);
2337        }
2338        return Ok((out, new_shape));
2339    }
2340
2341    // REFLECT / REPLICATE: the unified original-window per-axis gather
2342    // (`cpu/PaddingKernel.cpp:63-105`). Each output index reads a real input
2343    // element via the mode's per-axis index resolver.
2344    for o in 0..outer {
2345        let in_base = o * in_inner;
2346        let out_base = o * out_inner;
2347        for flat in 0..out_inner {
2348            let mut rem = flat;
2349            let mut src_lin = 0usize;
2350            let mut src_stride = 1usize;
2351            for k in 0..npad {
2352                let dim = ndim - 1 - k;
2353                let axis_new = new_shape[dim];
2354                let coord = rem % axis_new;
2355                rem /= axis_new;
2356                let (lo, hi) = pads[k];
2357                let s = signed_mode_axis_src(mode, coord, shape[dim], lo, hi);
2358                src_lin += s * src_stride;
2359                src_stride *= shape[dim];
2360            }
2361            out[out_base + flat] = data[in_base + src_lin];
2362        }
2363    }
2364
2365    Ok((out, new_shape))
2366}
2367
2368/// Backward for the signed reflect/circular pad: the adjoint of the unified
2369/// gather is a scatter-add into the original-size input
2370/// (`grad_input[src(o)] += grad_output[o]`), matching torch's
2371/// `reflection_pad*_backward` / `_pad_circular` backward.
2372#[derive(Debug)]
2373struct PadNdSignedModeBackward<T: Float> {
2374    input: Tensor<T>,
2375    input_shape: Vec<usize>,
2376    mode: PaddingMode,
2377    /// `(lo, hi)` per padded axis, ordered LAST axis first (same as the forward).
2378    pads: Vec<(isize, isize)>,
2379}
2380
2381impl<T: Float> GradFn<T> for PadNdSignedModeBackward<T> {
2382    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
2383        if !self.input.requires_grad() {
2384            return Ok(vec![None]);
2385        }
2386        let ndim = self.input_shape.len();
2387        let npad = self.pads.len();
2388        let first_padded = ndim - npad;
2389        let outer: usize = self.input_shape[..first_padded]
2390            .iter()
2391            .copied()
2392            .product::<usize>()
2393            .max(1);
2394        let in_inner: usize = self.input_shape[first_padded..].iter().product();
2395
2396        let go_shape = grad_output.shape();
2397        let out_inner: usize = go_shape[first_padded..].iter().product();
2398
2399        let go = grad_output.data_vec()?;
2400        let zero = <T as num_traits::Zero>::zero();
2401        let mut grad_in = vec![zero; outer * in_inner];
2402
2403        if self.mode == PaddingMode::Circular {
2404            // CIRCULAR backward is the scatter-add TRANSPOSE of the #1629
2405            // holistic forward `circular_slicecopy_block` (live-wrap slice-copy),
2406            // NOT the old per-axis `circular_axis_src` gather (which returned an
2407            // out-of-range source index for an over-cropped axis → index OOB
2408            // panic, #1631). We replay the SAME forward output→source mapping via
2409            // `circular_slicecopy_src_map` (center copy + live wraps), then for
2410            // each forward write `out[o] = in[src_map[o]]` scatter-add
2411            // `grad_out[o]` into `grad_in[src_map[o]]`. Cells the forward read
2412            // more than once accumulate their grads, matching torch's
2413            // `_pad_circular` backward (the transpose of `PadNd.cpp:176-179`
2414            // `out_slice.copy_(in_slice)` aliasing reads). Over-cropped /
2415            // net-zero-empty outputs produce zero forward writes ⇒ zero grad
2416            // contribution (no OOB, no panic — R-CODE-2).
2417            let in_inner_shape = &self.input_shape[first_padded..];
2418            let out_inner_shape = &go_shape[first_padded..];
2419            for o in 0..outer {
2420                let in_base = o * in_inner;
2421                let out_base = o * out_inner;
2422                let go_block = &go[out_base..out_base + out_inner];
2423                let gi_block = circular_slicecopy_backward_block(
2424                    go_block,
2425                    in_inner_shape,
2426                    out_inner_shape,
2427                    &self.pads,
2428                );
2429                for (i, &v) in gi_block.iter().enumerate() {
2430                    grad_in[in_base + i] += v;
2431                }
2432            }
2433            let grad_input =
2434                Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
2435            return Ok(vec![Some(grad_input)]);
2436        }
2437
2438        for o in 0..outer {
2439            let in_base = o * in_inner;
2440            let out_base = o * out_inner;
2441            for flat in 0..out_inner {
2442                let mut rem = flat;
2443                let mut src_lin = 0usize;
2444                let mut src_stride = 1usize;
2445                for k in 0..npad {
2446                    let dim = ndim - 1 - k;
2447                    let axis_new = go_shape[dim];
2448                    let coord = rem % axis_new;
2449                    rem /= axis_new;
2450                    let (lo, hi) = self.pads[k];
2451                    let s = signed_mode_axis_src(self.mode, coord, self.input_shape[dim], lo, hi);
2452                    src_lin += s * src_stride;
2453                    src_stride *= self.input_shape[dim];
2454                }
2455                grad_in[in_base + src_lin] += go[out_base + flat];
2456            }
2457        }
2458
2459        let grad_input =
2460            Tensor::from_storage(TensorStorage::cpu(grad_in), self.input_shape.clone(), false)?;
2461        Ok(vec![Some(grad_input)])
2462    }
2463
2464    fn inputs(&self) -> Vec<&Tensor<T>> {
2465        vec![&self.input]
2466    }
2467
2468    fn name(&self) -> &'static str {
2469        "PadNdSignedModeBackward"
2470    }
2471}
2472
2473/// Shared signed-pad driver for the 1-D/2-D/3-D public entrypoints. `pads` is
2474/// ordered LAST padded axis first.
2475///
2476/// For `PaddingMode::Zeros` (torch `mode="constant"`) negative pads narrow via
2477/// the signed-constant gather below. For reflect/replicate/circular, live torch
2478/// 2.11 does NOT reject a negative pad — `_pad_enum` dispatches straight to the
2479/// native `reflection_pad*` / `replication_pad*` / `_pad_circular` kernels,
2480/// which compute `output = input + pad_l + pad_r` directly (a negative pad
2481/// narrows the side) and offset the gather window by `max(0,-pad) - max(0,pad)`
2482/// (`aten/src/ATen/native/ReflectionPad.cpp:46`,
2483/// `aten/src/ATen/native/cpu/PaddingKernel.cpp:63-65`,
2484/// `aten/src/ATen/native/PadNd.cpp:158-159`). That is byte-identical to first
2485/// CROPPING the negative side(s) (constant-mode narrow) and then applying the
2486/// positive pad part with the mode's gather — verified against the live oracle
2487/// (`reflect [-1,0]` on `[1,2,3,4,5]` -> `[2,3,4,5]`; `replicate [1,-1]` ->
2488/// `[1,1,2,3,4]` grad `[2,1,1,1,0]`; `circular [-1,0]` -> `[2,3,4,5]` grad
2489/// `[0,1,1,1,1]`; `reflect2d [-1,1,0,0]` on the 3x3 -> `[[2,3,2],[5,6,5],
2490/// [8,9,8]]`). We compose crop-then-pad so the backward chains the crop adjoint
2491/// (zero-pad of the cropped side) with the mode-pad adjoint (the gather
2492/// scatter-add) through the normal autograd graph. Over-cropping a side
2493/// (`crop >= dim`) still errors via the signed-constant `narrow` check, matching
2494/// torch (`PadNd.cpp:221-242`).
2495fn functional_pad_nd_signed<T: Float>(
2496    input: &Tensor<T>,
2497    pads: &[(isize, isize)],
2498    mode: PaddingMode,
2499    value: T,
2500) -> FerrotorchResult<Tensor<T>> {
2501    let has_negative = pads.iter().any(|&(lo, hi)| lo < 0 || hi < 0);
2502
2503    if mode != PaddingMode::Zeros {
2504        if !has_negative {
2505            // All-non-negative under a non-constant mode: pure mode-pad.
2506            return functional_pad_nd_positive(input, pads, mode, value);
2507        }
2508        // Reflect/Replicate/Circular with a negative (crop) pad: torch does NOT
2509        // crop first. It reflects/clamps/wraps against the ORIGINAL input window
2510        // via a single index map with offset `max(0,-pad) - max(0,pad)`
2511        // (`aten/src/ATen/native/cpu/PaddingKernel.cpp:63-95`,
2512        // `ReflectionPad.cpp:46-48`, `PadNd.cpp:158-159`). A positive pad on a
2513        // cropped side reads elements a crop-first pass would have discarded
2514        // (e.g. `reflect [-3,2]` on `[1,2,3,4]` -> `[4,3,2]`, not an error).
2515        //
2516        // Replicate in particular MUST use the original-window clamp rather than
2517        // crop-then-pad: when a crop reduces an axis to size 0, the crop-first
2518        // path fed a zero-size axis to `pad_*_replicate`, which computed
2519        // `inner - 1` / `h - 1` and PANICKED (subtract-overflow). torch's
2520        // `ReplicationPad::index` (`PaddingKernel.cpp:84-95`) clamps the gather
2521        // to `[pad, size+pad-1]` against the ORIGINAL window, so an over-crop
2522        // still reads the preserved boundary element — no underflow, no panic
2523        // (#1625, R-CODE-2). We gather directly from the original window and
2524        // scatter-add the adjoint through `PadNdSignedModeBackward` (#1620 #1621
2525        // #1625).
2526        let data = input.data_vec()?;
2527        let shape = input.shape();
2528        if pads.len() > shape.len() {
2529            return Err(FerrotorchError::InvalidArgument {
2530                message: format!(
2531                    "pad targets {} dims but input has only {} dims",
2532                    pads.len(),
2533                    shape.len()
2534                ),
2535            });
2536        }
2537        let input_shape = shape.to_vec();
2538        let (out_data, new_shape) = pad_nd_signed_reflect_circular(&data, shape, pads, mode)?;
2539        if is_grad_enabled() && input.requires_grad() {
2540            let grad_fn = Arc::new(PadNdSignedModeBackward {
2541                input: input.clone(),
2542                input_shape,
2543                mode,
2544                pads: pads.to_vec(),
2545            });
2546            return Tensor::from_operation(TensorStorage::cpu(out_data), new_shape, grad_fn);
2547        }
2548        return Tensor::from_storage(TensorStorage::cpu(out_data), new_shape, false);
2549    }
2550
2551    let data = input.data_vec()?;
2552    let shape = input.shape();
2553    if pads.len() > shape.len() {
2554        return Err(FerrotorchError::InvalidArgument {
2555            message: format!(
2556                "pad targets {} dims but input has only {} dims",
2557                pads.len(),
2558                shape.len()
2559            ),
2560        });
2561    }
2562    let input_shape = shape.to_vec();
2563    let (out_data, new_shape) = pad_nd_signed_constant(&data, shape, pads, value)?;
2564
2565    // Grad path: attach PadNdSignedBackward so autograd stays connected (same
2566    // #1550 bug class the positive-only paths fixed).
2567    if is_grad_enabled() && input.requires_grad() {
2568        let grad_fn = Arc::new(PadNdSignedBackward {
2569            input: input.clone(),
2570            input_shape,
2571            pads: pads.to_vec(),
2572        });
2573        return Tensor::from_operation(TensorStorage::cpu(out_data), new_shape, grad_fn);
2574    }
2575
2576    Tensor::from_storage(TensorStorage::cpu(out_data), new_shape, false)
2577}
2578
2579/// Apply crop-capable padding to the last dimension of a tensor. Unlike
2580/// [`functional_pad_1d`] (which takes `usize`), the pad amounts are SIGNED: a
2581/// negative value crops `|pad|` elements off that side, mirroring
2582/// `torch.nn.functional.pad(input, [left, right], mode="constant", value=...)`
2583/// with negative `left`/`right` (`aten/src/ATen/native/PadNd.cpp:29-108`).
2584///
2585/// Negative (crop) pads are supported under EVERY mode: `Zeros` narrows via the
2586/// signed-constant gather, while reflect/replicate/circular crop the negative
2587/// side(s) then apply their gather on the positive part — byte-identical to
2588/// torch's native kernels, which compute `output = input + pad_l + pad_r`
2589/// directly (`PadNd.cpp:221-242`). Over-cropping (removing more than the
2590/// dimension holds) returns `InvalidArgument`, mirroring torch's
2591/// "narrow(): length must be non-negative".
2592pub fn functional_pad_1d_signed<T: Float>(
2593    input: &Tensor<T>,
2594    pad_left: isize,
2595    pad_right: isize,
2596    mode: PaddingMode,
2597    value: T,
2598) -> FerrotorchResult<Tensor<T>> {
2599    functional_pad_nd_signed(input, &[(pad_left, pad_right)], mode, value)
2600}
2601
2602/// Crop-capable padding for the last 2 dimensions. Signed analogue of
2603/// [`functional_pad_2d`]; see [`functional_pad_1d_signed`] for the crop
2604/// semantics and constant-mode restriction.
2605pub fn functional_pad_2d_signed<T: Float>(
2606    input: &Tensor<T>,
2607    pad_left: isize,
2608    pad_right: isize,
2609    pad_top: isize,
2610    pad_bottom: isize,
2611    mode: PaddingMode,
2612    value: T,
2613) -> FerrotorchResult<Tensor<T>> {
2614    // `pads` is LAST axis (W: left/right) first, then 2nd-last (H: top/bottom).
2615    functional_pad_nd_signed(
2616        input,
2617        &[(pad_left, pad_right), (pad_top, pad_bottom)],
2618        mode,
2619        value,
2620    )
2621}
2622
2623/// Crop-capable padding for the last 3 dimensions. Signed analogue of
2624/// [`functional_pad_3d`]; see [`functional_pad_1d_signed`] for the crop
2625/// semantics and constant-mode restriction.
2626// Public API: matches `torch.nn.functional.pad`'s 3-axis layout
2627// (left, right, top, bottom, front, back) — 6 signed pad amounts.
2628#[allow(clippy::too_many_arguments)]
2629pub fn functional_pad_3d_signed<T: Float>(
2630    input: &Tensor<T>,
2631    pad_left: isize,
2632    pad_right: isize,
2633    pad_top: isize,
2634    pad_bottom: isize,
2635    pad_front: isize,
2636    pad_back: isize,
2637    mode: PaddingMode,
2638    value: T,
2639) -> FerrotorchResult<Tensor<T>> {
2640    // LAST axis (W) first, then H, then D (front/back).
2641    functional_pad_nd_signed(
2642        input,
2643        &[
2644            (pad_left, pad_right),
2645            (pad_top, pad_bottom),
2646            (pad_front, pad_back),
2647        ],
2648        mode,
2649        value,
2650    )
2651}
2652
2653// ===========================================================================
2654// Macro to reduce boilerplate for Module implementations on padding layers
2655// ===========================================================================
2656
/// Implements `Module<T>` for a parameter-free padding layer.
///
/// The target type must be generic over `T: Float` and carry a
/// `training: bool` field. It must also provide an inherent
/// `fn pad(&self, &Tensor<T>) -> FerrotorchResult<Tensor<T>>`, which
/// `forward` delegates to.
macro_rules! impl_padding_module {
    ($name:ident) => {
        impl<T: Float> Module<T> for $name<T> {
            fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
                // The whole forward pass is the pad itself.
                self.pad(input)
            }

            // Padding layers own no learnable state, so every parameter view
            // is empty.
            fn parameters(&self) -> Vec<&Parameter<T>> {
                Vec::new()
            }

            fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
                Vec::new()
            }

            fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
                Vec::new()
            }

            fn is_training(&self) -> bool {
                self.training
            }

            fn train(&mut self) {
                self.training = true;
            }

            fn eval(&mut self) {
                self.training = false;
            }
        }
    };
}
2690
2691// ===========================================================================
2692// ConstantPad1d / ConstantPad2d / ConstantPad3d
2693// ===========================================================================
2694
2695/// Pads the last dimension of the input tensor with a constant value.
2696///
2697/// # Shape
2698/// - Input: `[*, L]`
2699/// - Output: `[*, L + pad_left + pad_right]`
2700#[derive(Debug)]
2701pub struct ConstantPad1d<T: Float> {
2702    /// Padding `(left, right)`.
2703    pub padding: (usize, usize),
2704    /// Constant fill value.
2705    pub value: T,
2706    training: bool,
2707}
2708
2709impl<T: Float> ConstantPad1d<T> {
2710    pub fn new(padding: (usize, usize), value: T) -> Self {
2711        Self {
2712            padding,
2713            value,
2714            training: true,
2715        }
2716    }
2717
2718    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2719        let data = input.data_vec()?;
2720        let (out, new_shape) = pad_1d_constant(
2721            &data,
2722            input.shape(),
2723            self.padding.0,
2724            self.padding.1,
2725            self.value,
2726        );
2727        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2728    }
2729}
2730
2731impl_padding_module!(ConstantPad1d);
2732
2733/// Pads the last 2 dimensions with a constant value.
2734///
2735/// # Shape
2736/// - Input: `[*, H, W]`
2737/// - Output: `[*, H + top + bottom, W + left + right]`
2738#[derive(Debug)]
2739pub struct ConstantPad2d<T: Float> {
2740    /// Padding `(left, right, top, bottom)`.
2741    pub padding: (usize, usize, usize, usize),
2742    /// Constant fill value.
2743    pub value: T,
2744    training: bool,
2745}
2746
2747impl<T: Float> ConstantPad2d<T> {
2748    pub fn new(padding: (usize, usize, usize, usize), value: T) -> Self {
2749        Self {
2750            padding,
2751            value,
2752            training: true,
2753        }
2754    }
2755
2756    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2757        if input.ndim() < 2 {
2758            return Err(FerrotorchError::InvalidArgument {
2759                message: format!(
2760                    "ConstantPad2d expects at least 2-D input, got {:?}",
2761                    input.shape()
2762                ),
2763            });
2764        }
2765        let data = input.data_vec()?;
2766        let (out, new_shape) = pad_2d_constant(
2767            &data,
2768            input.shape(),
2769            self.padding.0,
2770            self.padding.1,
2771            self.padding.2,
2772            self.padding.3,
2773            self.value,
2774        );
2775        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2776    }
2777}
2778
2779impl_padding_module!(ConstantPad2d);
2780
2781/// Pads the last 3 dimensions with a constant value.
2782///
2783/// # Shape
2784/// - Input: `[*, D, H, W]`
2785/// - Output: `[*, D + front + back, H + top + bottom, W + left + right]`
2786#[derive(Debug)]
2787pub struct ConstantPad3d<T: Float> {
2788    /// Padding `(left, right, top, bottom, front, back)`.
2789    pub padding: (usize, usize, usize, usize, usize, usize),
2790    /// Constant fill value.
2791    pub value: T,
2792    training: bool,
2793}
2794
2795impl<T: Float> ConstantPad3d<T> {
2796    pub fn new(padding: (usize, usize, usize, usize, usize, usize), value: T) -> Self {
2797        Self {
2798            padding,
2799            value,
2800            training: true,
2801        }
2802    }
2803
2804    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2805        if input.ndim() < 3 {
2806            return Err(FerrotorchError::InvalidArgument {
2807                message: format!(
2808                    "ConstantPad3d expects at least 3-D input, got {:?}",
2809                    input.shape()
2810                ),
2811            });
2812        }
2813        let data = input.data_vec()?;
2814        let (out, new_shape) = pad_3d_constant(
2815            &data,
2816            input.shape(),
2817            self.padding.0,
2818            self.padding.1,
2819            self.padding.2,
2820            self.padding.3,
2821            self.padding.4,
2822            self.padding.5,
2823            self.value,
2824        );
2825        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2826    }
2827}
2828
2829impl_padding_module!(ConstantPad3d);
2830
2831// ===========================================================================
2832// ZeroPad1d / ZeroPad2d / ZeroPad3d
2833// ===========================================================================
2834
2835/// Pads the last dimension with zeros.
2836#[derive(Debug)]
2837pub struct ZeroPad1d<T: Float> {
2838    pub padding: (usize, usize),
2839    training: bool,
2840    _phantom: std::marker::PhantomData<T>,
2841}
2842
2843impl<T: Float> ZeroPad1d<T> {
2844    pub fn new(padding: (usize, usize)) -> Self {
2845        Self {
2846            padding,
2847            training: true,
2848            _phantom: std::marker::PhantomData,
2849        }
2850    }
2851
2852    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2853        let data = input.data_vec()?;
2854        let zero = <T as num_traits::Zero>::zero();
2855        let (out, new_shape) =
2856            pad_1d_constant(&data, input.shape(), self.padding.0, self.padding.1, zero);
2857        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2858    }
2859}
2860
2861impl_padding_module!(ZeroPad1d);
2862
2863/// Pads the last 2 dimensions with zeros.
2864#[derive(Debug)]
2865pub struct ZeroPad2d<T: Float> {
2866    pub padding: (usize, usize, usize, usize),
2867    training: bool,
2868    _phantom: std::marker::PhantomData<T>,
2869}
2870
2871impl<T: Float> ZeroPad2d<T> {
2872    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2873        Self {
2874            padding,
2875            training: true,
2876            _phantom: std::marker::PhantomData,
2877        }
2878    }
2879
2880    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2881        if input.ndim() < 2 {
2882            return Err(FerrotorchError::InvalidArgument {
2883                message: format!(
2884                    "ZeroPad2d expects at least 2-D input, got {:?}",
2885                    input.shape()
2886                ),
2887            });
2888        }
2889        let data = input.data_vec()?;
2890        let zero = <T as num_traits::Zero>::zero();
2891        let (out, new_shape) = pad_2d_constant(
2892            &data,
2893            input.shape(),
2894            self.padding.0,
2895            self.padding.1,
2896            self.padding.2,
2897            self.padding.3,
2898            zero,
2899        );
2900        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2901    }
2902}
2903
2904impl_padding_module!(ZeroPad2d);
2905
2906/// Pads the last 3 dimensions with zeros.
2907#[derive(Debug)]
2908pub struct ZeroPad3d<T: Float> {
2909    pub padding: (usize, usize, usize, usize, usize, usize),
2910    training: bool,
2911    _phantom: std::marker::PhantomData<T>,
2912}
2913
2914impl<T: Float> ZeroPad3d<T> {
2915    pub fn new(padding: (usize, usize, usize, usize, usize, usize)) -> Self {
2916        Self {
2917            padding,
2918            training: true,
2919            _phantom: std::marker::PhantomData,
2920        }
2921    }
2922
2923    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2924        if input.ndim() < 3 {
2925            return Err(FerrotorchError::InvalidArgument {
2926                message: format!(
2927                    "ZeroPad3d expects at least 3-D input, got {:?}",
2928                    input.shape()
2929                ),
2930            });
2931        }
2932        let data = input.data_vec()?;
2933        let zero = <T as num_traits::Zero>::zero();
2934        let (out, new_shape) = pad_3d_constant(
2935            &data,
2936            input.shape(),
2937            self.padding.0,
2938            self.padding.1,
2939            self.padding.2,
2940            self.padding.3,
2941            self.padding.4,
2942            self.padding.5,
2943            zero,
2944        );
2945        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2946    }
2947}
2948
2949impl_padding_module!(ZeroPad3d);
2950
2951// ===========================================================================
2952// ReflectionPad1d / ReflectionPad2d / ReflectionPad3d
2953// ===========================================================================
2954
2955/// Pads the last dimension using reflection of the input boundary.
2956#[derive(Debug)]
2957pub struct ReflectionPad1d<T: Float> {
2958    pub padding: (usize, usize),
2959    training: bool,
2960    _phantom: std::marker::PhantomData<T>,
2961}
2962
2963impl<T: Float> ReflectionPad1d<T> {
2964    pub fn new(padding: (usize, usize)) -> Self {
2965        Self {
2966            padding,
2967            training: true,
2968            _phantom: std::marker::PhantomData,
2969        }
2970    }
2971
2972    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
2973        let data = input.data_vec()?;
2974        let (out, new_shape) =
2975            pad_1d_reflect(&data, input.shape(), self.padding.0, self.padding.1)?;
2976        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
2977    }
2978}
2979
2980impl_padding_module!(ReflectionPad1d);
2981
2982/// Pads the last 2 dimensions using reflection.
2983#[derive(Debug)]
2984pub struct ReflectionPad2d<T: Float> {
2985    pub padding: (usize, usize, usize, usize),
2986    training: bool,
2987    _phantom: std::marker::PhantomData<T>,
2988}
2989
2990impl<T: Float> ReflectionPad2d<T> {
2991    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
2992        Self {
2993            padding,
2994            training: true,
2995            _phantom: std::marker::PhantomData,
2996        }
2997    }
2998
2999    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3000        if input.ndim() < 2 {
3001            return Err(FerrotorchError::InvalidArgument {
3002                message: format!(
3003                    "ReflectionPad2d expects at least 2-D input, got {:?}",
3004                    input.shape()
3005                ),
3006            });
3007        }
3008        let data = input.data_vec()?;
3009        let (out, new_shape) = pad_2d_reflect(
3010            &data,
3011            input.shape(),
3012            self.padding.0,
3013            self.padding.1,
3014            self.padding.2,
3015            self.padding.3,
3016        )?;
3017        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
3018    }
3019}
3020
3021impl_padding_module!(ReflectionPad2d);
3022
3023/// Pads the last 3 dimensions using reflection.
3024#[derive(Debug)]
3025pub struct ReflectionPad3d<T: Float> {
3026    pub padding: (usize, usize, usize, usize, usize, usize),
3027    training: bool,
3028    _phantom: std::marker::PhantomData<T>,
3029}
3030
3031impl<T: Float> ReflectionPad3d<T> {
3032    pub fn new(padding: (usize, usize, usize, usize, usize, usize)) -> Self {
3033        Self {
3034            padding,
3035            training: true,
3036            _phantom: std::marker::PhantomData,
3037        }
3038    }
3039
3040    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3041        if input.ndim() < 3 {
3042            return Err(FerrotorchError::InvalidArgument {
3043                message: format!(
3044                    "ReflectionPad3d expects at least 3-D input, got {:?}",
3045                    input.shape()
3046                ),
3047            });
3048        }
3049        let data = input.data_vec()?;
3050        let (out, new_shape) = pad_3d_reflect(
3051            &data,
3052            input.shape(),
3053            self.padding.0,
3054            self.padding.1,
3055            self.padding.2,
3056            self.padding.3,
3057            self.padding.4,
3058            self.padding.5,
3059        )?;
3060        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
3061    }
3062}
3063
3064impl_padding_module!(ReflectionPad3d);
3065
3066// ===========================================================================
3067// ReplicationPad1d / ReplicationPad2d / ReplicationPad3d
3068// ===========================================================================
3069
3070/// Pads the last dimension by replicating the edge values.
3071#[derive(Debug)]
3072pub struct ReplicationPad1d<T: Float> {
3073    pub padding: (usize, usize),
3074    training: bool,
3075    _phantom: std::marker::PhantomData<T>,
3076}
3077
3078impl<T: Float> ReplicationPad1d<T> {
3079    pub fn new(padding: (usize, usize)) -> Self {
3080        Self {
3081            padding,
3082            training: true,
3083            _phantom: std::marker::PhantomData,
3084        }
3085    }
3086
3087    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3088        let data = input.data_vec()?;
3089        let (out, new_shape) =
3090            pad_1d_replicate(&data, input.shape(), self.padding.0, self.padding.1);
3091        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
3092    }
3093}
3094
3095impl_padding_module!(ReplicationPad1d);
3096
3097/// Pads the last 2 dimensions by replicating edge values.
3098#[derive(Debug)]
3099pub struct ReplicationPad2d<T: Float> {
3100    pub padding: (usize, usize, usize, usize),
3101    training: bool,
3102    _phantom: std::marker::PhantomData<T>,
3103}
3104
3105impl<T: Float> ReplicationPad2d<T> {
3106    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
3107        Self {
3108            padding,
3109            training: true,
3110            _phantom: std::marker::PhantomData,
3111        }
3112    }
3113
3114    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3115        if input.ndim() < 2 {
3116            return Err(FerrotorchError::InvalidArgument {
3117                message: format!(
3118                    "ReplicationPad2d expects at least 2-D input, got {:?}",
3119                    input.shape()
3120                ),
3121            });
3122        }
3123        let data = input.data_vec()?;
3124        let (out, new_shape) = pad_2d_replicate(
3125            &data,
3126            input.shape(),
3127            self.padding.0,
3128            self.padding.1,
3129            self.padding.2,
3130            self.padding.3,
3131        );
3132        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
3133    }
3134}
3135
3136impl_padding_module!(ReplicationPad2d);
3137
3138/// Pads the last 3 dimensions by replicating edge values.
3139#[derive(Debug)]
3140pub struct ReplicationPad3d<T: Float> {
3141    pub padding: (usize, usize, usize, usize, usize, usize),
3142    training: bool,
3143    _phantom: std::marker::PhantomData<T>,
3144}
3145
3146impl<T: Float> ReplicationPad3d<T> {
3147    pub fn new(padding: (usize, usize, usize, usize, usize, usize)) -> Self {
3148        Self {
3149            padding,
3150            training: true,
3151            _phantom: std::marker::PhantomData,
3152        }
3153    }
3154
3155    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3156        if input.ndim() < 3 {
3157            return Err(FerrotorchError::InvalidArgument {
3158                message: format!(
3159                    "ReplicationPad3d expects at least 3-D input, got {:?}",
3160                    input.shape()
3161                ),
3162            });
3163        }
3164        let data = input.data_vec()?;
3165        let (out, new_shape) = pad_3d_replicate(
3166            &data,
3167            input.shape(),
3168            self.padding.0,
3169            self.padding.1,
3170            self.padding.2,
3171            self.padding.3,
3172            self.padding.4,
3173            self.padding.5,
3174        );
3175        Tensor::from_storage(TensorStorage::cpu(out), new_shape, false)
3176    }
3177}
3178
3179impl_padding_module!(ReplicationPad3d);
3180
3181// ===========================================================================
3182// CircularPad — wraps data circularly (periodic boundary conditions)
3183// ===========================================================================
3184
3185/// 1-D circular padding: wraps the input circularly.
3186///
3187/// Input: [N, C, W]. Pads the W dimension with circular (periodic) values.
3188/// Matches PyTorch's `nn.CircularPad1d`.
3189#[derive(Debug, Clone)]
3190pub struct CircularPad1d<T: Float> {
3191    pub padding: (usize, usize),
3192    training: bool,
3193    _phantom: std::marker::PhantomData<T>,
3194}
3195
3196impl<T: Float> CircularPad1d<T> {
3197    pub fn new(padding: (usize, usize)) -> Self {
3198        Self {
3199            padding,
3200            training: true,
3201            _phantom: std::marker::PhantomData,
3202        }
3203    }
3204
3205    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3206        if input.ndim() != 3 {
3207            return Err(FerrotorchError::InvalidArgument {
3208                message: format!(
3209                    "CircularPad1d: expected 3-D input [N,C,W], got {:?}",
3210                    input.shape()
3211                ),
3212            });
3213        }
3214        if input.is_cuda() {
3215            return Err(FerrotorchError::NotImplementedOnCuda {
3216                op: "CircularPad1d",
3217            });
3218        }
3219        let shape = input.shape();
3220        let (n, c, w) = (shape[0], shape[1], shape[2]);
3221        let (pl, pr) = self.padding;
3222        let new_w = w + pl + pr;
3223        let data = input.data()?;
3224        let zero = <T as num_traits::Zero>::zero();
3225        let mut out = vec![zero; n * c * new_w];
3226
3227        for batch in 0..n {
3228            for ch in 0..c {
3229                for ow in 0..new_w {
3230                    let iw = ((ow as isize - pl as isize).rem_euclid(w as isize)) as usize;
3231                    out[batch * c * new_w + ch * new_w + ow] = data[batch * c * w + ch * w + iw];
3232                }
3233            }
3234        }
3235
3236        Tensor::from_storage(TensorStorage::cpu(out), vec![n, c, new_w], false)
3237    }
3238}
3239
3240impl<T: Float> Default for CircularPad1d<T> {
3241    fn default() -> Self {
3242        Self::new((0, 0))
3243    }
3244}
3245
3246impl_padding_module!(CircularPad1d);
3247
3248/// 2-D circular padding. Input: [N, C, H, W].
3249/// Matches PyTorch's `nn.CircularPad2d`.
3250#[derive(Debug, Clone)]
3251pub struct CircularPad2d<T: Float> {
3252    pub padding: (usize, usize, usize, usize),
3253    training: bool,
3254    _phantom: std::marker::PhantomData<T>,
3255}
3256
3257impl<T: Float> CircularPad2d<T> {
3258    pub fn new(padding: (usize, usize, usize, usize)) -> Self {
3259        Self {
3260            padding,
3261            training: true,
3262            _phantom: std::marker::PhantomData,
3263        }
3264    }
3265
3266    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3267        if input.ndim() != 4 {
3268            return Err(FerrotorchError::InvalidArgument {
3269                message: format!(
3270                    "CircularPad2d: expected 4-D input [N,C,H,W], got {:?}",
3271                    input.shape()
3272                ),
3273            });
3274        }
3275        if input.is_cuda() {
3276            return Err(FerrotorchError::NotImplementedOnCuda {
3277                op: "CircularPad2d",
3278            });
3279        }
3280        let shape = input.shape();
3281        let (n, c, h, w) = (shape[0], shape[1], shape[2], shape[3]);
3282        let (pl, pr, pt, pb) = self.padding;
3283        let new_h = h + pt + pb;
3284        let new_w = w + pl + pr;
3285        let data = input.data()?;
3286        let zero = <T as num_traits::Zero>::zero();
3287        let mut out = vec![zero; n * c * new_h * new_w];
3288
3289        for batch in 0..n {
3290            for ch in 0..c {
3291                for oh in 0..new_h {
3292                    let ih = ((oh as isize - pt as isize).rem_euclid(h as isize)) as usize;
3293                    for ow in 0..new_w {
3294                        let iw = ((ow as isize - pl as isize).rem_euclid(w as isize)) as usize;
3295                        out[batch * c * new_h * new_w + ch * new_h * new_w + oh * new_w + ow] =
3296                            data[batch * c * h * w + ch * h * w + ih * w + iw];
3297                    }
3298                }
3299            }
3300        }
3301
3302        Tensor::from_storage(TensorStorage::cpu(out), vec![n, c, new_h, new_w], false)
3303    }
3304}
3305
3306impl<T: Float> Default for CircularPad2d<T> {
3307    fn default() -> Self {
3308        Self::new((0, 0, 0, 0))
3309    }
3310}
3311
3312impl_padding_module!(CircularPad2d);
3313
3314/// 3-D circular padding. Input: [N, C, D, H, W].
3315/// Matches PyTorch's `nn.CircularPad3d`.
3316#[derive(Debug, Clone)]
3317pub struct CircularPad3d<T: Float> {
3318    pub padding: (usize, usize, usize, usize, usize, usize),
3319    training: bool,
3320    _phantom: std::marker::PhantomData<T>,
3321}
3322
3323impl<T: Float> CircularPad3d<T> {
3324    pub fn new(padding: (usize, usize, usize, usize, usize, usize)) -> Self {
3325        Self {
3326            padding,
3327            training: true,
3328            _phantom: std::marker::PhantomData,
3329        }
3330    }
3331
3332    fn pad(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
3333        if input.ndim() != 5 {
3334            return Err(FerrotorchError::InvalidArgument {
3335                message: format!(
3336                    "CircularPad3d: expected 5-D input [N,C,D,H,W], got {:?}",
3337                    input.shape()
3338                ),
3339            });
3340        }
3341        if input.is_cuda() {
3342            return Err(FerrotorchError::NotImplementedOnCuda {
3343                op: "CircularPad3d",
3344            });
3345        }
3346        let shape = input.shape();
3347        let (n, c, d, h, w) = (shape[0], shape[1], shape[2], shape[3], shape[4]);
3348        let (pl, pr, pt, pb, pf, pk) = self.padding;
3349        let (new_d, new_h, new_w) = (d + pf + pk, h + pt + pb, w + pl + pr);
3350        let data = input.data()?;
3351        let zero = <T as num_traits::Zero>::zero();
3352        let mut out = vec![zero; n * c * new_d * new_h * new_w];
3353
3354        for batch in 0..n {
3355            for ch in 0..c {
3356                for od in 0..new_d {
3357                    let id = ((od as isize - pf as isize).rem_euclid(d as isize)) as usize;
3358                    for oh in 0..new_h {
3359                        let ih = ((oh as isize - pt as isize).rem_euclid(h as isize)) as usize;
3360                        for ow in 0..new_w {
3361                            let iw = ((ow as isize - pl as isize).rem_euclid(w as isize)) as usize;
3362                            out[batch * c * new_d * new_h * new_w
3363                                + ch * new_d * new_h * new_w
3364                                + od * new_h * new_w
3365                                + oh * new_w
3366                                + ow] = data
3367                                [batch * c * d * h * w + ch * d * h * w + id * h * w + ih * w + iw];
3368                        }
3369                    }
3370                }
3371            }
3372        }
3373
3374        Tensor::from_storage(
3375            TensorStorage::cpu(out),
3376            vec![n, c, new_d, new_h, new_w],
3377            false,
3378        )
3379    }
3380}
3381
3382impl<T: Float> Default for CircularPad3d<T> {
3383    fn default() -> Self {
3384        Self::new((0, 0, 0, 0, 0, 0))
3385    }
3386}
3387
3388impl_padding_module!(CircularPad3d);
3389
3390// ===========================================================================
3391// Tests
3392// ===========================================================================
3393
3394#[cfg(test)]
3395mod tests {
3396    use super::*;
3397    use crate::module::Module;
3398
3399    fn t(data: &[f32], shape: &[usize]) -> Tensor<f32> {
3400        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
3401    }
3402
3403    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
3404        assert_eq!(
3405            actual.len(),
3406            expected.len(),
3407            "length mismatch: {} vs {}",
3408            actual.len(),
3409            expected.len()
3410        );
3411        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
3412            assert!((a - e).abs() < tol, "index {i}: actual={a} expected={e}");
3413        }
3414    }
3415
3416    // -----------------------------------------------------------------------
3417    // ConstantPad1d
3418    // -----------------------------------------------------------------------
3419
3420    #[test]
3421    fn test_constant_pad1d_basic() {
3422        let pad = ConstantPad1d::<f32>::new((2, 3), 9.0);
3423        let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
3424        let output = pad.forward(&input).unwrap();
3425        assert_eq!(output.shape(), &[1, 1, 8]);
3426        assert_close(
3427            output.data().unwrap(),
3428            &[9.0, 9.0, 1.0, 2.0, 3.0, 9.0, 9.0, 9.0],
3429            1e-7,
3430        );
3431    }
3432
3433    // -----------------------------------------------------------------------
3434    // ZeroPad1d
3435    // -----------------------------------------------------------------------
3436
3437    #[test]
3438    fn test_zero_pad1d() {
3439        let pad = ZeroPad1d::<f32>::new((1, 2));
3440        let input = t(&[1.0, 2.0, 3.0], &[3]);
3441        let output = pad.forward(&input).unwrap();
3442        assert_eq!(output.shape(), &[6]);
3443        assert_close(
3444            output.data().unwrap(),
3445            &[0.0, 1.0, 2.0, 3.0, 0.0, 0.0],
3446            1e-7,
3447        );
3448    }
3449
3450    // -----------------------------------------------------------------------
3451    // ZeroPad2d
3452    // -----------------------------------------------------------------------
3453
3454    #[test]
3455    fn test_zero_pad2d() {
3456        let pad = ZeroPad2d::<f32>::new((1, 1, 1, 1));
3457        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
3458        let output = pad.forward(&input).unwrap();
3459        assert_eq!(output.shape(), &[1, 1, 4, 4]);
3460        #[rustfmt::skip]
3461        let expected = [
3462            0.0, 0.0, 0.0, 0.0,
3463            0.0, 1.0, 2.0, 0.0,
3464            0.0, 3.0, 4.0, 0.0,
3465            0.0, 0.0, 0.0, 0.0,
3466        ];
3467        assert_close(output.data().unwrap(), &expected, 1e-7);
3468    }
3469
3470    // -----------------------------------------------------------------------
3471    // ZeroPad3d
3472    // -----------------------------------------------------------------------
3473
3474    #[test]
3475    fn test_zero_pad3d_shape() {
3476        let pad = ZeroPad3d::<f32>::new((1, 1, 1, 1, 1, 1));
3477        let input = t(&[1.0; 2 * 2 * 2], &[1, 1, 2, 2, 2]);
3478        let output = pad.forward(&input).unwrap();
3479        assert_eq!(output.shape(), &[1, 1, 4, 4, 4]);
3480    }
3481
3482    // -----------------------------------------------------------------------
3483    // ReflectionPad1d
3484    // -----------------------------------------------------------------------
3485
3486    #[test]
3487    fn test_reflection_pad1d() {
3488        let pad = ReflectionPad1d::<f32>::new((2, 2));
3489        // input = [1, 2, 3, 4]
3490        let input = t(&[1.0, 2.0, 3.0, 4.0], &[4]);
3491        let output = pad.forward(&input).unwrap();
3492        assert_eq!(output.shape(), &[8]);
3493        // Reflect left: [3, 2, | 1, 2, 3, 4 | 3, 2]
3494        assert_close(
3495            output.data().unwrap(),
3496            &[3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 3.0, 2.0],
3497            1e-7,
3498        );
3499    }
3500
3501    #[test]
3502    fn test_reflection_pad1d_too_large() {
3503        let pad = ReflectionPad1d::<f32>::new((4, 0));
3504        let input = t(&[1.0, 2.0, 3.0], &[3]); // size 3, pad 4 >= 3
3505        assert!(pad.forward(&input).is_err());
3506    }
3507
3508    // -----------------------------------------------------------------------
3509    // ReflectionPad2d
3510    // -----------------------------------------------------------------------
3511
3512    #[test]
3513    fn test_reflection_pad2d() {
3514        let pad = ReflectionPad2d::<f32>::new((1, 1, 1, 1));
3515        #[rustfmt::skip]
3516        let input = t(&[
3517            1.0, 2.0, 3.0,
3518            4.0, 5.0, 6.0,
3519            7.0, 8.0, 9.0,
3520        ], &[1, 1, 3, 3]);
3521        let output = pad.forward(&input).unwrap();
3522        assert_eq!(output.shape(), &[1, 1, 5, 5]);
3523        // Corner (0,0) should reflect to (1,1) in src = 5.0
3524        let out = output.data().unwrap();
3525        assert_close(&out[0..1], &[5.0], 1e-7); // top-left corner
3526    }
3527
3528    // -----------------------------------------------------------------------
3529    // ReplicationPad1d
3530    // -----------------------------------------------------------------------
3531
3532    #[test]
3533    fn test_replication_pad1d() {
3534        let pad = ReplicationPad1d::<f32>::new((2, 3));
3535        let input = t(&[1.0, 2.0, 3.0], &[3]);
3536        let output = pad.forward(&input).unwrap();
3537        assert_eq!(output.shape(), &[8]);
3538        assert_close(
3539            output.data().unwrap(),
3540            &[1.0, 1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0],
3541            1e-7,
3542        );
3543    }
3544
3545    // -----------------------------------------------------------------------
3546    // ReplicationPad2d
3547    // -----------------------------------------------------------------------
3548
3549    #[test]
3550    fn test_replication_pad2d() {
3551        let pad = ReplicationPad2d::<f32>::new((1, 1, 1, 1));
3552        #[rustfmt::skip]
3553        let input = t(&[
3554            1.0, 2.0,
3555            3.0, 4.0,
3556        ], &[1, 1, 2, 2]);
3557        let output = pad.forward(&input).unwrap();
3558        assert_eq!(output.shape(), &[1, 1, 4, 4]);
3559        #[rustfmt::skip]
3560        let expected = [
3561            1.0, 1.0, 2.0, 2.0,
3562            1.0, 1.0, 2.0, 2.0,
3563            3.0, 3.0, 4.0, 4.0,
3564            3.0, 3.0, 4.0, 4.0,
3565        ];
3566        assert_close(output.data().unwrap(), &expected, 1e-7);
3567    }
3568
3569    // -----------------------------------------------------------------------
3570    // ConstantPad2d
3571    // -----------------------------------------------------------------------
3572
3573    #[test]
3574    fn test_constant_pad2d() {
3575        let pad = ConstantPad2d::<f32>::new((1, 1, 1, 1), -1.0);
3576        let input = t(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
3577        let output = pad.forward(&input).unwrap();
3578        assert_eq!(output.shape(), &[4, 4]);
3579        #[rustfmt::skip]
3580        let expected = [
3581            -1.0, -1.0, -1.0, -1.0,
3582            -1.0, 5.0, 6.0, -1.0,
3583            -1.0, 7.0, 8.0, -1.0,
3584            -1.0, -1.0, -1.0, -1.0,
3585        ];
3586        assert_close(output.data().unwrap(), &expected, 1e-7);
3587    }
3588
3589    // -----------------------------------------------------------------------
3590    // ConstantPad3d
3591    // -----------------------------------------------------------------------
3592
3593    #[test]
3594    fn test_constant_pad3d_shape() {
3595        let pad = ConstantPad3d::<f32>::new((1, 2, 1, 2, 1, 2), 0.0);
3596        let input = t(&vec![1.0; 3 * 4 * 5], &[1, 1, 3, 4, 5]);
3597        let output = pad.forward(&input).unwrap();
3598        assert_eq!(output.shape(), &[1, 1, 6, 7, 8]);
3599    }
3600
3601    // -----------------------------------------------------------------------
3602    // Circular padding (1D)
3603    // -----------------------------------------------------------------------
3604
3605    #[test]
3606    fn test_circular_pad_1d() {
3607        // input = [1, 2, 3, 4], pad_left=1, pad_right=2
3608        // circular: [4, 1, 2, 3, 4, 1, 2]
3609        let data = [1.0f32, 2.0, 3.0, 4.0];
3610        let (out, new_shape) = pad_1d_circular(&data, &[4], 1, 2);
3611        assert_eq!(new_shape, &[7]);
3612        assert_close(&out, &[4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0], 1e-7);
3613    }
3614
3615    // -----------------------------------------------------------------------
3616    // Padding mode enum
3617    // -----------------------------------------------------------------------
3618
3619    #[test]
3620    fn test_padding_mode_eq() {
3621        assert_eq!(PaddingMode::Zeros, PaddingMode::Zeros);
3622        assert_ne!(PaddingMode::Zeros, PaddingMode::Reflect);
3623    }
3624
3625    // -----------------------------------------------------------------------
3626    // Module trait: no parameters
3627    // -----------------------------------------------------------------------
3628
3629    #[test]
3630    fn test_padding_module_no_params() {
3631        let pad = ZeroPad2d::<f32>::new((1, 1, 1, 1));
3632        assert!(pad.parameters().is_empty());
3633        assert!(pad.named_parameters().is_empty());
3634    }
3635
3636    #[test]
3637    fn test_padding_module_train_eval() {
3638        let mut pad = ReflectionPad1d::<f32>::new((1, 1));
3639        assert!(pad.is_training());
3640        pad.eval();
3641        assert!(!pad.is_training());
3642        pad.train();
3643        assert!(pad.is_training());
3644    }
3645
3646    // -----------------------------------------------------------------------
3647    // Degenerate (numel-0) constant pad — regression for #1551.
3648    //
3649    // op_db emits pad samples whose input has an empty data buffer paired
3650    // with a non-empty *declared* last dim (e.g. shape `[0, 3]`: numel 0,
3651    // inner 3). Previously `pad_{1,2,3}d_constant` forced rows/outer to 1 and
3652    // then read `inner`/`w` elements from the empty `data` slice, panicking
3653    // with "range end index N out of range for slice of length 0" at the
3654    // `copy_from_slice`. Upstream `torch.nn.functional.pad`
3655    // (`aten/src/ATen/native/PadNd.cpp:94-106`) allocates the padded output,
3656    // `fill_(value)`s it, then `copy_`s the (empty) source — a no-op — so the
3657    // result is the correctly-shaped, value-filled tensor. These assert the
3658    // fixed behaviour: no panic + correct output shape on numel-0 input.
3659    // -----------------------------------------------------------------------
3660
3661    #[test]
3662    fn test_constant_pad1d_empty_numel_no_panic() {
3663        // shape [0, 3]: numel 0 but inner = 3. data buffer is empty.
3664        let (out, new_shape) = pad_1d_constant::<f32>(&[], &[0, 3], 2, 3, 7.0);
3665        // last dim padded 3 -> 3+2+3 = 8; outer 0-dim with forced row count 1.
3666        assert_eq!(new_shape, vec![0, 8]);
3667        // value-filled output, no source copied in.
3668        assert!(out.iter().all(|&v| v == 7.0));
3669    }
3670
3671    #[test]
3672    fn test_constant_pad2d_empty_numel_no_panic() {
3673        // shape [0, 2, 3]: numel 0, h = 2, w = 3, empty data.
3674        let (out, new_shape) = pad_2d_constant::<f32>(&[], &[0, 2, 3], 1, 1, 1, 1, 5.0);
3675        assert_eq!(new_shape, vec![0, 4, 5]);
3676        assert!(out.iter().all(|&v| v == 5.0));
3677    }
3678
3679    #[test]
3680    fn test_constant_pad3d_empty_numel_no_panic() {
3681        // shape [0, 2, 2, 3]: numel 0, d = 2, h = 2, w = 3, empty data.
3682        let (out, new_shape) = pad_3d_constant::<f32>(&[], &[0, 2, 2, 3], 1, 1, 1, 1, 1, 1, 3.0);
3683        assert_eq!(new_shape, vec![0, 4, 4, 5]);
3684        assert!(out.iter().all(|&v| v == 3.0));
3685    }
3686
3687    // -----------------------------------------------------------------------
3688    // Regression: `functional_pad_{1,2,3}d` constant-mode must use `value`.
3689    //
3690    // The runner maps torch `mode="constant"` -> `PaddingMode::Zeros` and passes
3691    // the `value` kwarg through. Pre-fix the `Zeros` arm hardcoded `T::zero()`
3692    // and dropped `value` (`let _ = value;`), so `F.pad(x, p, "constant", 2.0)`
3693    // filled 0 instead of 2 — 256 parity-sweep failures (ferrotorch=0 vs
3694    // torch=2). Upstream `aten/src/ATen/native/PadNd.cpp:94` does
3695    // `output.fill_(value)` before copying the source. #1553.
3696    // -----------------------------------------------------------------------
3697
3698    #[test]
3699    fn test_functional_pad_1d_constant_uses_value() {
3700        let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
3701        let out = functional_pad_1d(&input, 1, 1, PaddingMode::Zeros, 2.0).unwrap();
3702        assert_eq!(out.shape(), &[1, 1, 5]);
3703        // Padded region (first + last) must be the fill `value` 2.0, not 0.0.
3704        assert_close(out.data().unwrap(), &[2.0, 1.0, 2.0, 3.0, 2.0], 1e-7);
3705    }
3706
3707    #[test]
3708    fn test_functional_pad_2d_constant_uses_value() {
3709        // 1x1x2x2 input, pad (left, right, top, bottom) = (1, 1, 1, 1).
3710        let input = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 2, 2]);
3711        let out = functional_pad_2d(&input, 1, 1, 1, 1, PaddingMode::Zeros, 2.0).unwrap();
3712        assert_eq!(out.shape(), &[1, 1, 4, 4]);
3713        #[rustfmt::skip]
3714        let expected = [
3715            2.0, 2.0, 2.0, 2.0,
3716            2.0, 1.0, 2.0, 2.0,
3717            2.0, 3.0, 4.0, 2.0,
3718            2.0, 2.0, 2.0, 2.0,
3719        ];
3720        assert_close(out.data().unwrap(), &expected, 1e-7);
3721        // The border is the fill value; no padded cell is 0.
3722        assert!(out.data().unwrap().iter().all(|&v| v != 0.0));
3723    }
3724
3725    #[test]
3726    fn test_functional_pad_3d_constant_uses_value() {
3727        // 1x1x1x1x1 input, pad all six axes by 0 except left/right by 1.
3728        let input = t(&[5.0], &[1, 1, 1, 1, 1]);
3729        let out = functional_pad_3d(&input, 1, 1, 0, 0, 0, 0, PaddingMode::Zeros, 2.0).unwrap();
3730        assert_eq!(out.shape(), &[1, 1, 1, 1, 3]);
3731        assert_close(out.data().unwrap(), &[2.0, 5.0, 2.0], 1e-7);
3732    }
3733
3734    // -----------------------------------------------------------------------
3735    // Autograd-aware functional pad (Pad1dBackward / Pad3dBackward) — #1443.
3736    //
3737    // These are the pre-pad helpers Conv1d/Conv3d route non-zero padding_modes
3738    // through; a pad returning requires_grad=false severs autograd (the #1550
3739    // bug class the 2-D path already fixed). Expected gradients are from a live
3740    // PyTorch 2.11 `F.pad(...).sum().backward()` oracle (R-CHAR-3); the oracle
3741    // script is in the #1443 commit body.
3742    // -----------------------------------------------------------------------
3743
3744    /// Helper: leaf tensor that requires grad.
3745    fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
3746        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
3747    }
3748
3749    /// `functional_pad_1d` Reflect attaches `Pad1dBackward` and scatter-adds the
3750    /// grad back onto the source row. torch: F.pad([1,2,3,4], (2,2), 'reflect')
3751    /// -> out [3,2,1,2,3,4,3,2]; sum().backward() grad_input = [1,3,3,1].
3752    #[test]
3753    fn test_functional_pad_1d_reflect_backward_matches_torch() {
3754        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4]);
3755        let y = functional_pad_1d(&x, 2, 2, PaddingMode::Reflect, 0.0).unwrap();
3756        assert_eq!(y.shape(), &[1, 1, 8]);
3757        assert!(
3758            y.grad_fn().is_some(),
3759            "functional_pad_1d Reflect lost grad_fn — would sever Conv1d autograd (#1550 class)"
3760        );
3761        assert_eq!(y.grad_fn().unwrap().name(), "Pad1dBackward");
3762        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
3763        ferrotorch_core::backward(&sum).unwrap();
3764        let g = x.grad().unwrap().expect("grad must be populated");
3765        assert_close(g.data().unwrap(), &[1.0, 3.0, 3.0, 1.0], 1e-5);
3766    }
3767
3768    /// `functional_pad_3d` Circular attaches `Pad3dBackward`. torch: a circular
3769    /// pad of (1,1,1,1,1,1) on a 2x2x2 volume wraps every cell exactly 8 times,
3770    /// so the all-ones grad_output backprops to a uniform grad of 8.
3771    #[test]
3772    fn test_functional_pad_3d_circular_backward_matches_torch() {
3773        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
3774        let x = leaf(&x_data, &[1, 1, 2, 2, 2]);
3775        let y = functional_pad_3d(&x, 1, 1, 1, 1, 1, 1, PaddingMode::Circular, 0.0).unwrap();
3776        assert_eq!(y.shape(), &[1, 1, 4, 4, 4]);
3777        assert!(y.grad_fn().is_some());
3778        assert_eq!(y.grad_fn().unwrap().name(), "Pad3dBackward");
3779        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
3780        ferrotorch_core::backward(&sum).unwrap();
3781        let g = x.grad().unwrap().expect("grad must be populated");
3782        assert_close(g.data().unwrap(), &[8.0; 8], 1e-5);
3783    }
3784
3785    // -----------------------------------------------------------------------
    // Negative (crop) padding — `torch.nn.functional.pad` with negative pad
    // amounts CROPS that side instead of adding (#1611). For the constant
    // (`PaddingMode::Zeros`) path, upstream
    // `aten/src/ATen/native/PadNd.cpp:29-108` (`constant_pad_nd`) narrows the
    // input for negative pads, fills the output with `value`, and copies the
    // cropped input into the positive-pad window. Reflect/replicate/circular
    // do NOT reject negative pads: they crop as well (#1620), as pinned by
    // `test_functional_pad_signed_negative_non_constant_crops` below.
3793    //
3794    // All expected forward + backward (sum().backward()) values below are from
3795    // a live PyTorch 2.11 oracle (R-CHAR-3); the deriving script is in the
3796    // #1611 commit body. Each block names the exact `F.pad(...)` call it pins.
3797    // -----------------------------------------------------------------------
3798
3799    /// torch: `F.pad(torch.tensor([[[1,2,3,4,5]]]), [-1,-1], "constant")`
3800    /// -> out [2,3,4]; sum().backward() grad_input = [0,1,1,1,0].
3801    #[test]
3802    fn test_functional_pad_1d_signed_crop_both_matches_torch() {
3803        let x = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0], &[1, 1, 5]);
3804        let y = functional_pad_1d_signed(&x, -1, -1, PaddingMode::Zeros, 0.0).unwrap();
3805        assert_eq!(y.shape(), &[1, 1, 3]);
3806        assert_close(y.data().unwrap(), &[2.0, 3.0, 4.0], 1e-7);
3807        assert_eq!(y.grad_fn().unwrap().name(), "PadNdSignedBackward");
3808        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
3809        ferrotorch_core::backward(&sum).unwrap();
3810        let g = x.grad().unwrap().expect("grad must be populated");
3811        assert_close(g.data().unwrap(), &[0.0, 1.0, 1.0, 1.0, 0.0], 1e-7);
3812    }
3813
3814    /// Mixed signs: torch
3815    /// `F.pad(torch.tensor([[[1,2,3,4]]]), [-1,2], "constant", value=9)`
3816    /// -> out [2,3,4,9,9] (crop 1 from start, add 2 fill at end);
3817    /// sum().backward() grad_input = [0,1,1,1].
3818    #[test]
3819    fn test_functional_pad_1d_signed_mixed_matches_torch() {
3820        let x = leaf(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4]);
3821        let y = functional_pad_1d_signed(&x, -1, 2, PaddingMode::Zeros, 9.0).unwrap();
3822        assert_eq!(y.shape(), &[1, 1, 5]);
3823        assert_close(y.data().unwrap(), &[2.0, 3.0, 4.0, 9.0, 9.0], 1e-7);
3824        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
3825        ferrotorch_core::backward(&sum).unwrap();
3826        let g = x.grad().unwrap().expect("grad must be populated");
3827        assert_close(g.data().unwrap(), &[0.0, 1.0, 1.0, 1.0], 1e-7);
3828    }
3829
3830    /// 2-D crop: torch `F.pad(3x3, [-1,0, 0,-1], "constant")` crops the right
3831    /// column (last dim) and the bottom row (2nd-last) -> 2x2 [[2,3],[5,6]];
3832    /// sum().backward() grad = [[0,1,1],[0,1,1],[0,0,0]] (flattened).
3833    #[test]
3834    fn test_functional_pad_2d_signed_crop_matches_torch() {
3835        #[rustfmt::skip]
3836        let x = leaf(&[
3837            1.0, 2.0, 3.0,
3838            4.0, 5.0, 6.0,
3839            7.0, 8.0, 9.0,
3840        ], &[1, 1, 3, 3]);
3841        let y = functional_pad_2d_signed(&x, -1, 0, 0, -1, PaddingMode::Zeros, 0.0).unwrap();
3842        assert_eq!(y.shape(), &[1, 1, 2, 2]);
3843        assert_close(y.data().unwrap(), &[2.0, 3.0, 5.0, 6.0], 1e-7);
3844        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
3845        ferrotorch_core::backward(&sum).unwrap();
3846        let g = x.grad().unwrap().expect("grad must be populated");
3847        assert_close(
3848            g.data().unwrap(),
3849            &[0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0],
3850            1e-7,
3851        );
3852    }
3853
3854    /// 2-D mixed signs: torch
3855    /// `F.pad(2x3, [-1,2, 1,-1], "constant", value=7)` (last dim crop1/add2,
3856    /// 2nd-last add1/crop1) -> 2x4 [[7,7,7,7],[2,3,7,7]];
3857    /// sum().backward() grad = [[0,1,1],[0,0,0]] (flattened).
3858    #[test]
3859    fn test_functional_pad_2d_signed_mixed_matches_torch() {
3860        #[rustfmt::skip]
3861        let x = leaf(&[
3862            1.0, 2.0, 3.0,
3863            4.0, 5.0, 6.0,
3864        ], &[1, 1, 2, 3]);
3865        let y = functional_pad_2d_signed(&x, -1, 2, 1, -1, PaddingMode::Zeros, 7.0).unwrap();
3866        assert_eq!(y.shape(), &[1, 1, 2, 4]);
3867        #[rustfmt::skip]
3868        let expected = [
3869            7.0, 7.0, 7.0, 7.0,
3870            2.0, 3.0, 7.0, 7.0,
3871        ];
3872        assert_close(y.data().unwrap(), &expected, 1e-7);
3873        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
3874        ferrotorch_core::backward(&sum).unwrap();
3875        let g = x.grad().unwrap().expect("grad must be populated");
3876        assert_close(g.data().unwrap(), &[0.0, 1.0, 1.0, 0.0, 0.0, 0.0], 1e-7);
3877    }
3878
3879    /// 3-D crop: torch `F.pad(2x2x2 [1..8], [-1,0, 0,-1, -1,0], "constant")`
3880    /// (W crop right, H crop bottom, D crop front) -> 1x1x1 [6];
3881    /// sum().backward() grad = [0,0,0,0,0,1,0,0].
3882    #[test]
3883    fn test_functional_pad_3d_signed_crop_matches_torch() {
3884        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
3885        let x = leaf(&x_data, &[1, 1, 2, 2, 2]);
3886        let y = functional_pad_3d_signed(&x, -1, 0, 0, -1, -1, 0, PaddingMode::Zeros, 0.0).unwrap();
3887        assert_eq!(y.shape(), &[1, 1, 1, 1, 1]);
3888        assert_close(y.data().unwrap(), &[6.0], 1e-7);
3889        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
3890        ferrotorch_core::backward(&sum).unwrap();
3891        let g = x.grad().unwrap().expect("grad must be populated");
3892        assert_close(
3893            g.data().unwrap(),
3894            &[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0],
3895            1e-7,
3896        );
3897    }
3898
3899    /// 3-D mixed signs incl. positive adds: torch
3900    /// `F.pad(2x2x2 [1..8], [1,-1, 0,1, -1,2], "constant", value=3)`
3901    /// -> 3x3x2; sum().backward() grad = [0,0,0,0,1,0,1,0].
3902    #[test]
3903    fn test_functional_pad_3d_signed_mixed_matches_torch() {
3904        let x_data: Vec<f32> = (1..=8).map(|v| v as f32).collect();
3905        let x = leaf(&x_data, &[1, 1, 2, 2, 2]);
3906        let y = functional_pad_3d_signed(&x, 1, -1, 0, 1, -1, 2, PaddingMode::Zeros, 3.0).unwrap();
3907        assert_eq!(y.shape(), &[1, 1, 3, 3, 2]);
3908        #[rustfmt::skip]
3909        let expected = [
3910            3.0, 5.0, 3.0, 7.0, 3.0, 3.0, 3.0, 3.0, 3.0,
3911            3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0,
3912        ];
3913        assert_close(y.data().unwrap(), &expected, 1e-7);
3914        let sum = ferrotorch_core::grad_fns::reduction::sum(&y).unwrap();
3915        ferrotorch_core::backward(&sum).unwrap();
3916        let g = x.grad().unwrap().expect("grad must be populated");
3917        assert_close(
3918            g.data().unwrap(),
3919            &[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0],
3920            1e-7,
3921        );
3922    }
3923
3924    /// Over-crop: torch raises `RuntimeError: narrow(): length must be
3925    /// non-negative` when a single side crops more than the dim holds
3926    /// (`F.pad([[[1,2,3]]], [-4,0])`) or the combined net size is negative
3927    /// (`F.pad([[[1,2,3]]], [-2,-2])`). ferrotorch returns `InvalidArgument`.
3928    #[test]
3929    fn test_functional_pad_1d_signed_over_crop_errors() {
3930        // Single side over-crops (left 4 from size 3).
3931        let x = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
3932        assert!(
3933            functional_pad_1d_signed(&x, -4, 0, PaddingMode::Zeros, 0.0).is_err(),
3934            "single-side over-crop must error like torch narrow()"
3935        );
3936        // Combined net negative size (left 2 + right 2 from size 3 -> -1).
3937        assert!(
3938            functional_pad_1d_signed(&x, -2, -2, PaddingMode::Zeros, 0.0).is_err(),
3939            "combined net-negative crop must error like torch"
3940        );
3941        // Right side over-crops after left (left 1 -> size 2, right 3 -> -1).
3942        assert!(
3943            functional_pad_1d_signed(&x, -1, -3, PaddingMode::Zeros, 0.0).is_err(),
3944            "right-after-left over-crop must error like torch"
3945        );
3946    }
3947
3948    /// Net-zero crop is NOT an error in torch: `F.pad([[[1,2,3]]], [-1,-2])`
3949    /// returns an empty dim `[1,1,0]`. ferrotorch must match (no error).
3950    #[test]
3951    fn test_functional_pad_1d_signed_net_zero_empty_dim_matches_torch() {
3952        let x = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
3953        let y = functional_pad_1d_signed(&x, -1, -2, PaddingMode::Zeros, 0.0).unwrap();
3954        assert_eq!(y.shape(), &[1, 1, 0]);
3955        assert!(y.data().unwrap().is_empty());
3956    }
3957
3958    /// Negative (crop) pad under a non-constant mode CROPS — live torch 2.11's
3959    /// `_pad_enum` dispatches reflect/replicate/circular straight to the native
3960    /// kernels, which narrow for negative pads (`PadNd.cpp:221-242`). For
3961    /// `[-1, 0]` on `[1,2,3,4]` all three modes crop the left element, yielding
3962    /// `[2,3,4]` (the positive part of the pad is zero, so it is a pure crop).
3963    /// torch: `F.pad([[[1.,2.,3.,4.]]], [-1,0], mode=<m>)` -> shape [1,1,3],
3964    /// `[2,3,4]` for reflect/replicate/circular alike (#1620).
3965    #[test]
3966    fn test_functional_pad_signed_negative_non_constant_crops() {
3967        let x = t(&[1.0, 2.0, 3.0, 4.0], &[1, 1, 4]);
3968        for mode in [
3969            PaddingMode::Reflect,
3970            PaddingMode::Replicate,
3971            PaddingMode::Circular,
3972        ] {
3973            let y = functional_pad_1d_signed(&x, -1, 0, mode, 0.0)
3974                .unwrap_or_else(|_| panic!("negative pad under {mode:?} must crop, not error"));
3975            assert_eq!(
3976                y.shape(),
3977                &[1, 1, 3],
3978                "{mode:?} crops left -> shape [1,1,3]"
3979            );
3980            assert_close(y.data().unwrap(), &[2.0, 3.0, 4.0], 1e-7);
3981        }
3982    }
3983
3984    /// A non-negative signed pad must be byte-identical to the existing
3985    /// positive-only `functional_pad_1d` (the delegation invariant that makes
3986    /// the signed path the single source of truth for constant padding without
3987    /// changing conv.rs's production behaviour). torch:
3988    /// `F.pad([[[1,2,3]]], [1,1], "constant", value=2)` -> [2,1,2,3,2].
3989    #[test]
3990    fn test_functional_pad_1d_signed_nonneg_equals_positive_path() {
3991        let input = t(&[1.0, 2.0, 3.0], &[1, 1, 3]);
3992        let signed = functional_pad_1d_signed(&input, 1, 1, PaddingMode::Zeros, 2.0).unwrap();
3993        let positive = functional_pad_1d(&input, 1, 1, PaddingMode::Zeros, 2.0).unwrap();
3994        assert_eq!(signed.shape(), positive.shape());
3995        assert_close(signed.data().unwrap(), positive.data().unwrap(), 1e-7);
3996        assert_close(signed.data().unwrap(), &[2.0, 1.0, 2.0, 3.0, 2.0], 1e-7);
3997    }
3998}