Skip to main content

ferrotorch_nn/
upsample.rs

1//! Upsample, interpolation, and vision ops.
2//!
3//! This module provides spatial resizing and transformation modules for
4//! vision workloads:
5//!
6//! - [`Upsample`] — Upsamples a `[B, C, H, W]` tensor using nearest, bilinear,
7//!   or bicubic interpolation.
8//! - [`PixelShuffle`] / [`PixelUnshuffle`] — Sub-pixel convolution for
9//!   efficient super-resolution (`[B, C*r*r, H, W]` <-> `[B, C, H*r, W*r]`).
10//! - [`Fold`] / [`Unfold`] — Sliding-window patch extraction and reconstruction.
11//!
12//! All autograd-tracked operations attach a `GradFn<T>` when gradient tracking
13//! is enabled so reverse-mode differentiation works out of the box.
14//!
15//! CL-317: Upsample, Interpolation & Vision Ops
16//!
17//! ## REQ status (per `.design/ferrotorch-nn/upsample.md`)
18//!
19//! | REQ | Status | Evidence |
20//! |---|---|---|
21//! | REQ-1 | SHIPPED | the `InterpolateMode` enum here; non-test consumer: re-export at `ferrotorch-nn/src/lib.rs:252` + `ferrotorch-vision/src/models/segmentation/deeplabv3.rs:52` + `aspp.rs:39` + `lraspp.rs:50` + `fcn.rs:36` |
22//! | REQ-2 | SHIPPED | the `GridSamplePaddingMode` and `GridSampleMode` enums here; non-test consumer: re-export at `lib.rs:252` |
23//! | REQ-3 | SHIPPED | the `interpolate<T>` entry here; non-test consumer: re-export at `lib.rs:252` + every segmentation model in `ferrotorch-vision` |
24//! | REQ-4 | SHIPPED | the `Upsample` struct + `impl<T: Float> Module<T> for Upsample` here; non-test consumer: re-export at `lib.rs:252` |
25//! | REQ-5 | SHIPPED | the `grid_sample<T>` entry here; non-test consumer: re-export at `lib.rs:252` |
26//! | REQ-6 | SHIPPED | the `affine_grid<T>` entry here; non-test consumer: re-export at `lib.rs:252` |
27//! | REQ-7 | SHIPPED | the `PixelShuffle` + `PixelUnshuffle` structs + their `impl<T: Float> Module<T>` blocks here; non-test consumer: re-export at `lib.rs:252` |
28//! | REQ-8 | SHIPPED | the `pixel_shuffle<T>` + `pixel_unshuffle<T>` functional entries here; non-test consumer: re-export at `lib.rs:252` |
29//! | REQ-9 | SHIPPED | the `Unfold` + `Fold` structs + their `impl<T: Float> Module<T>` blocks and `unfold` + `fold` functional entries here; non-test consumer: re-export at `lib.rs:252` |
30//! | REQ-10 | SHIPPED | the `cubic_weight` helper here; non-test consumer: invoked from the bicubic branch of `interpolate` (re-exported at `lib.rs:252`) |
31//! | REQ-11 | SHIPPED | `align_corners`-aware source-coordinate formulas inside the interpolate / grid_sample bodies here; non-test consumer: re-export at `lib.rs:252` |
32//! | REQ-12 | SHIPPED | per-op `GradFn<T>` types + `Tensor::from_operation` calls here; non-test consumer: re-export at `lib.rs:252` |
33
34use std::sync::Arc;
35
36use ferrotorch_core::autograd::no_grad::is_grad_enabled;
37use ferrotorch_core::tensor::GradFn;
38use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
39
40use crate::module::Module;
41use crate::parameter::Parameter;
42
43// ===========================================================================
44// Interpolation mode
45// ===========================================================================
46
/// Resampling kernel selected by [`Upsample`] and [`interpolate`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InterpolateMode {
    /// Every output pixel copies a single source pixel.
    Nearest,
    /// Every output pixel is a weighted average of 4 source pixels.
    Bilinear,
    /// Every output pixel is a cubic-kernel weighted sum of 16 source pixels.
    Bicubic,
}
57
/// How [`grid_sample`] treats sample locations that fall outside the input.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GridSamplePaddingMode {
    /// Locations outside the input boundary read as zero.
    Zeros,
    /// Coordinates are clamped to the border of the input.
    Border,
    /// Coordinates are reflected at the border of the input.
    Reflection,
}
68
/// How [`grid_sample`] reads the input at a (generally fractional) location.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GridSampleMode {
    /// Weighted average of the four pixels surrounding the location.
    Bilinear,
    /// Value of the single pixel at the rounded location.
    Nearest,
}
77
78// ===========================================================================
79// Helpers
80// ===========================================================================
81
82/// Validate that the input tensor has shape `[B, C, H, W]`.
83fn validate_4d<T: Float>(
84    input: &Tensor<T>,
85    fn_name: &str,
86) -> FerrotorchResult<(usize, usize, usize, usize)> {
87    let shape = input.shape();
88    if shape.len() != 4 {
89        return Err(FerrotorchError::InvalidArgument {
90            message: format!(
91                "{fn_name} expects 4D input [B, C, H, W], got shape {:?}",
92                shape
93            ),
94        });
95    }
96    Ok((shape[0], shape[1], shape[2], shape[3]))
97}
98
/// Keys' cubic convolution kernel with coefficient `a = -0.75`.
///
/// `t` is the signed distance between the sample position and a source pixel.
/// The kernel is symmetric in `t`, has support `|t| < 2`, and evaluates to
/// `0.0` everywhere else. The computation is carried out entirely in `f64`.
#[inline]
fn cubic_weight(t: f64) -> f64 {
    const A: f64 = -0.75;

    match t.abs() {
        // Inner lobe: |t| <= 1.
        d if d <= 1.0 => (A + 2.0) * d * d * d - (A + 3.0) * d * d + 1.0,
        // Outer lobe: 1 < |t| < 2.
        d if d < 2.0 => A * d * d * d - 5.0 * A * d * d + 8.0 * A * d - 4.0 * A,
        // Outside the support.
        _ => 0.0,
    }
}
116
/// Source coordinate of output index `i` under `align_corners=true`.
///
/// Maps `[0, out_size - 1]` linearly onto `[0, in_size - 1]`, so the first and
/// last output samples land exactly on the first and last input samples.
///
/// Degenerate extents never panic: when `out_size <= 1` the result is `0.0`,
/// and an empty input (`in_size == 0`) is treated as having a zero-length
/// span, which also yields `0.0`.
#[inline]
fn align_corners_coord(i: usize, in_size: usize, out_size: usize) -> f64 {
    if out_size <= 1 {
        return 0.0;
    }
    // `saturating_sub` keeps `in_size == 0` from underflowing `usize`.
    let in_span = in_size.saturating_sub(1) as f64;
    let out_span = (out_size - 1) as f64;
    (i as f64) * in_span / out_span
}
128
/// Source coordinate of output index `i` under `align_corners=false`.
///
/// Half-pixel convention: the centre of output pixel `i` (at `i + 0.5`) is
/// scaled by `in_size / out_size` and shifted back by half a pixel. The result
/// can be negative or exceed `in_size - 1`.
#[inline]
fn half_pixel_coord(i: usize, in_size: usize, out_size: usize) -> f64 {
    let ratio = in_size as f64 / out_size as f64;
    let centre = i as f64 + 0.5;
    centre * ratio - 0.5
}
136
/// Converts a signed coordinate to an index in `[0, max]`: negative values
/// become `0`, values above `max` become `max`.
#[inline]
fn clamp_coord(val: isize, max: usize) -> usize {
    let non_negative = val.max(0) as usize;
    non_negative.min(max)
}
148
149// ===========================================================================
150// interpolate — functional API
151// ===========================================================================
152
153/// Interpolation target size. Exactly one of `size` or `scale_factor` must be
154/// provided (the other is `None`).
155///
156/// CL-317
157pub fn interpolate<T: Float>(
158    input: &Tensor<T>,
159    size: Option<[usize; 2]>,
160    scale_factor: Option<[f64; 2]>,
161    mode: InterpolateMode,
162    align_corners: bool,
163) -> FerrotorchResult<Tensor<T>> {
164    let (batch, channels, h_in, w_in) = validate_4d(input, "interpolate")?;
165
166    // Resolve target size.
167    let (h_out, w_out) = match (size, scale_factor) {
168        (Some(s), None) => (s[0], s[1]),
169        (None, Some(sf)) => {
170            let h = (h_in as f64 * sf[0]).round() as usize;
171            let w = (w_in as f64 * sf[1]).round() as usize;
172            if h == 0 || w == 0 {
173                return Err(FerrotorchError::InvalidArgument {
174                    message: format!(
175                        "interpolate: scale_factor {sf:?} with input ({h_in}, {w_in}) produces zero output"
176                    ),
177                });
178            }
179            (h, w)
180        }
181        _ => {
182            return Err(FerrotorchError::InvalidArgument {
183                message: "interpolate: exactly one of size or scale_factor must be provided".into(),
184            });
185        }
186    };
187
188    if h_out == 0 || w_out == 0 {
189        return Err(FerrotorchError::InvalidArgument {
190            message: format!("interpolate: output size ({h_out}, {w_out}) must be > 0"),
191        });
192    }
193
194    if mode == InterpolateMode::Nearest && align_corners {
195        return Err(FerrotorchError::InvalidArgument {
196            message: "interpolate: align_corners is not supported with nearest mode".into(),
197        });
198    }
199
200    let input_device = input.device();
201    let data = input.data_vec()?;
202
203    let total = batch * channels * h_out * w_out;
204    let mut output = vec![T::from(0.0).unwrap(); total];
205
206    match mode {
207        InterpolateMode::Nearest => {
208            nearest_forward(
209                &data,
210                &mut output,
211                batch,
212                channels,
213                h_in,
214                w_in,
215                h_out,
216                w_out,
217            );
218        }
219        InterpolateMode::Bilinear => {
220            bilinear_forward(
221                &data,
222                &mut output,
223                batch,
224                channels,
225                h_in,
226                w_in,
227                h_out,
228                w_out,
229                align_corners,
230            );
231        }
232        InterpolateMode::Bicubic => {
233            bicubic_forward(
234                &data,
235                &mut output,
236                batch,
237                channels,
238                h_in,
239                w_in,
240                h_out,
241                w_out,
242                align_corners,
243            );
244        }
245    }
246
247    let out_shape = vec![batch, channels, h_out, w_out];
248    let storage = TensorStorage::cpu(output);
249
250    if is_grad_enabled() && input.requires_grad() {
251        Tensor::from_operation(
252            storage,
253            out_shape,
254            Arc::new(InterpolateBackward {
255                input: input.clone(),
256                h_out,
257                w_out,
258                mode,
259                align_corners,
260            }),
261        )?
262        .to(input_device)
263    } else {
264        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
265    }
266}
267
268// ---------------------------------------------------------------------------
269// Forward kernels
270// ---------------------------------------------------------------------------
271
272// Internal kernel: argument set is the upsample descriptor
273// (B, C, H_in, W_in, H_out, W_out, scale_h, scale_w); a config struct
274// would force allocation in the hot interpolate path.
275#[allow(clippy::too_many_arguments)]
276fn nearest_forward<T: Float>(
277    data: &[T],
278    output: &mut [T],
279    batch: usize,
280    channels: usize,
281    h_in: usize,
282    w_in: usize,
283    h_out: usize,
284    w_out: usize,
285) {
286    let h_scale = h_in as f64 / h_out as f64;
287    let w_scale = w_in as f64 / w_out as f64;
288
289    for b in 0..batch {
290        for c in 0..channels {
291            for oh in 0..h_out {
292                let ih = ((oh as f64 * h_scale).floor() as usize).min(h_in - 1);
293                for ow in 0..w_out {
294                    let iw = ((ow as f64 * w_scale).floor() as usize).min(w_in - 1);
295                    let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
296                    let in_idx = ((b * channels + c) * h_in + ih) * w_in + iw;
297                    output[out_idx] = data[in_idx];
298                }
299            }
300        }
301    }
302}
303
304// Internal kernel: same upsample descriptor as `nearest_forward`.
305#[allow(clippy::too_many_arguments)]
306fn bilinear_forward<T: Float>(
307    data: &[T],
308    output: &mut [T],
309    batch: usize,
310    channels: usize,
311    h_in: usize,
312    w_in: usize,
313    h_out: usize,
314    w_out: usize,
315    align_corners: bool,
316) {
317    let one = T::from(1.0).unwrap();
318
319    for b in 0..batch {
320        for c in 0..channels {
321            for oh in 0..h_out {
322                let src_h = if align_corners {
323                    align_corners_coord(oh, h_in, h_out)
324                } else {
325                    half_pixel_coord(oh, h_in, h_out)
326                };
327
328                let h0 = src_h.floor() as isize;
329                let h1 = h0 + 1;
330                let th = T::from(src_h - h0 as f64).unwrap();
331
332                for ow in 0..w_out {
333                    let src_w = if align_corners {
334                        align_corners_coord(ow, w_in, w_out)
335                    } else {
336                        half_pixel_coord(ow, w_in, w_out)
337                    };
338
339                    let w0 = src_w.floor() as isize;
340                    let w1 = w0 + 1;
341                    let tw = T::from(src_w - w0 as f64).unwrap();
342
343                    let ch0 = clamp_coord(h0, h_in - 1);
344                    let ch1 = clamp_coord(h1, h_in - 1);
345                    let cw0 = clamp_coord(w0, w_in - 1);
346                    let cw1 = clamp_coord(w1, w_in - 1);
347
348                    let base = (b * channels + c) * h_in;
349                    let v00 = data[(base + ch0) * w_in + cw0];
350                    let v01 = data[(base + ch0) * w_in + cw1];
351                    let v10 = data[(base + ch1) * w_in + cw0];
352                    let v11 = data[(base + ch1) * w_in + cw1];
353
354                    let val = v00 * (one - th) * (one - tw)
355                        + v01 * (one - th) * tw
356                        + v10 * th * (one - tw)
357                        + v11 * th * tw;
358
359                    let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
360                    output[out_idx] = val;
361                }
362            }
363        }
364    }
365}
366
367// Internal kernel: same upsample descriptor as `nearest_forward`.
368#[allow(clippy::too_many_arguments)]
369fn bicubic_forward<T: Float>(
370    data: &[T],
371    output: &mut [T],
372    batch: usize,
373    channels: usize,
374    h_in: usize,
375    w_in: usize,
376    h_out: usize,
377    w_out: usize,
378    align_corners: bool,
379) {
380    for b in 0..batch {
381        for c in 0..channels {
382            for oh in 0..h_out {
383                let src_h = if align_corners {
384                    align_corners_coord(oh, h_in, h_out)
385                } else {
386                    half_pixel_coord(oh, h_in, h_out)
387                };
388
389                let h_floor = src_h.floor() as isize;
390                let frac_h = src_h - h_floor as f64;
391
392                // Precompute 4 vertical kernel weights.
393                let wh: [T; 4] = [
394                    T::from(cubic_weight(frac_h + 1.0)).unwrap(),
395                    T::from(cubic_weight(frac_h)).unwrap(),
396                    T::from(cubic_weight(frac_h - 1.0)).unwrap(),
397                    T::from(cubic_weight(frac_h - 2.0)).unwrap(),
398                ];
399
400                for ow in 0..w_out {
401                    let src_w = if align_corners {
402                        align_corners_coord(ow, w_in, w_out)
403                    } else {
404                        half_pixel_coord(ow, w_in, w_out)
405                    };
406
407                    let w_floor = src_w.floor() as isize;
408                    let frac_w = src_w - w_floor as f64;
409
410                    let ww: [T; 4] = [
411                        T::from(cubic_weight(frac_w + 1.0)).unwrap(),
412                        T::from(cubic_weight(frac_w)).unwrap(),
413                        T::from(cubic_weight(frac_w - 1.0)).unwrap(),
414                        T::from(cubic_weight(frac_w - 2.0)).unwrap(),
415                    ];
416
417                    let mut val = T::from(0.0).unwrap();
418                    let base = (b * channels + c) * h_in;
419
420                    for dy in 0..4isize {
421                        let iy = clamp_coord(h_floor - 1 + dy, h_in - 1);
422                        for dx in 0..4isize {
423                            let ix = clamp_coord(w_floor - 1 + dx, w_in - 1);
424                            let pixel = data[(base + iy) * w_in + ix];
425                            val += pixel * wh[dy as usize] * ww[dx as usize];
426                        }
427                    }
428
429                    let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
430                    output[out_idx] = val;
431                }
432            }
433        }
434    }
435}
436
437// ---------------------------------------------------------------------------
438// Backward for interpolate
439// ---------------------------------------------------------------------------
440
441#[derive(Debug)]
442struct InterpolateBackward<T: Float> {
443    input: Tensor<T>,
444    h_out: usize,
445    w_out: usize,
446    mode: InterpolateMode,
447    align_corners: bool,
448}
449
450impl<T: Float> GradFn<T> for InterpolateBackward<T> {
451    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
452        if !self.input.requires_grad() {
453            return Ok(vec![None]);
454        }
455
456        let in_shape = self.input.shape();
457        let (batch, channels, h_in, w_in) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
458        let h_out = self.h_out;
459        let w_out = self.w_out;
460
461        let go_data = grad_output.data_vec()?;
462        let mut grad_input = vec![T::from(0.0).unwrap(); batch * channels * h_in * w_in];
463
464        match self.mode {
465            InterpolateMode::Nearest => {
466                nearest_backward(
467                    &go_data,
468                    &mut grad_input,
469                    batch,
470                    channels,
471                    h_in,
472                    w_in,
473                    h_out,
474                    w_out,
475                );
476            }
477            InterpolateMode::Bilinear => {
478                bilinear_backward(
479                    &go_data,
480                    &mut grad_input,
481                    batch,
482                    channels,
483                    h_in,
484                    w_in,
485                    h_out,
486                    w_out,
487                    self.align_corners,
488                );
489            }
490            InterpolateMode::Bicubic => {
491                bicubic_backward(
492                    &go_data,
493                    &mut grad_input,
494                    batch,
495                    channels,
496                    h_in,
497                    w_in,
498                    h_out,
499                    w_out,
500                    self.align_corners,
501                );
502            }
503        }
504
505        let grad_tensor = Tensor::from_storage(
506            TensorStorage::cpu(grad_input),
507            self.input.shape().to_vec(),
508            false,
509        )?;
510        Ok(vec![Some(grad_tensor)])
511    }
512
513    fn inputs(&self) -> Vec<&Tensor<T>> {
514        vec![&self.input]
515    }
516
517    fn name(&self) -> &'static str {
518        "InterpolateBackward"
519    }
520}
521
522// Internal kernel: adjoint of `nearest_forward`; same descriptor.
523#[allow(clippy::too_many_arguments)]
524fn nearest_backward<T: Float>(
525    go: &[T],
526    grad_input: &mut [T],
527    batch: usize,
528    channels: usize,
529    h_in: usize,
530    w_in: usize,
531    h_out: usize,
532    w_out: usize,
533) {
534    let h_scale = h_in as f64 / h_out as f64;
535    let w_scale = w_in as f64 / w_out as f64;
536
537    for b in 0..batch {
538        for c in 0..channels {
539            for oh in 0..h_out {
540                let ih = ((oh as f64 * h_scale).floor() as usize).min(h_in - 1);
541                for ow in 0..w_out {
542                    let iw = ((ow as f64 * w_scale).floor() as usize).min(w_in - 1);
543                    let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
544                    let in_idx = ((b * channels + c) * h_in + ih) * w_in + iw;
545                    grad_input[in_idx] += go[out_idx];
546                }
547            }
548        }
549    }
550}
551
552// Internal kernel: adjoint of `bilinear_forward`; same descriptor.
553#[allow(clippy::too_many_arguments)]
554fn bilinear_backward<T: Float>(
555    go: &[T],
556    grad_input: &mut [T],
557    batch: usize,
558    channels: usize,
559    h_in: usize,
560    w_in: usize,
561    h_out: usize,
562    w_out: usize,
563    align_corners: bool,
564) {
565    let one = T::from(1.0).unwrap();
566
567    for b in 0..batch {
568        for c in 0..channels {
569            for oh in 0..h_out {
570                let src_h = if align_corners {
571                    align_corners_coord(oh, h_in, h_out)
572                } else {
573                    half_pixel_coord(oh, h_in, h_out)
574                };
575
576                let h0 = src_h.floor() as isize;
577                let h1 = h0 + 1;
578                let th = T::from(src_h - h0 as f64).unwrap();
579
580                for ow in 0..w_out {
581                    let src_w = if align_corners {
582                        align_corners_coord(ow, w_in, w_out)
583                    } else {
584                        half_pixel_coord(ow, w_in, w_out)
585                    };
586
587                    let w0 = src_w.floor() as isize;
588                    let w1 = w0 + 1;
589                    let tw = T::from(src_w - w0 as f64).unwrap();
590
591                    let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
592                    let g = go[out_idx];
593
594                    let ch0 = clamp_coord(h0, h_in - 1);
595                    let ch1 = clamp_coord(h1, h_in - 1);
596                    let cw0 = clamp_coord(w0, w_in - 1);
597                    let cw1 = clamp_coord(w1, w_in - 1);
598
599                    let base = (b * channels + c) * h_in;
600
601                    grad_input[(base + ch0) * w_in + cw0] += g * (one - th) * (one - tw);
602                    grad_input[(base + ch0) * w_in + cw1] += g * (one - th) * tw;
603                    grad_input[(base + ch1) * w_in + cw0] += g * th * (one - tw);
604                    grad_input[(base + ch1) * w_in + cw1] += g * th * tw;
605                }
606            }
607        }
608    }
609}
610
611// Internal kernel: adjoint of `bicubic_forward`; same descriptor.
612#[allow(clippy::too_many_arguments)]
613fn bicubic_backward<T: Float>(
614    go: &[T],
615    grad_input: &mut [T],
616    batch: usize,
617    channels: usize,
618    h_in: usize,
619    w_in: usize,
620    h_out: usize,
621    w_out: usize,
622    align_corners: bool,
623) {
624    for b in 0..batch {
625        for c in 0..channels {
626            for oh in 0..h_out {
627                let src_h: f64 = if align_corners {
628                    align_corners_coord(oh, h_in, h_out)
629                } else {
630                    half_pixel_coord(oh, h_in, h_out)
631                };
632
633                let h_floor = src_h.floor() as isize;
634                let frac_h = src_h - h_floor as f64;
635
636                let wh: [T; 4] = [
637                    T::from(cubic_weight(frac_h + 1.0)).unwrap(),
638                    T::from(cubic_weight(frac_h)).unwrap(),
639                    T::from(cubic_weight(frac_h - 1.0)).unwrap(),
640                    T::from(cubic_weight(frac_h - 2.0)).unwrap(),
641                ];
642
643                for ow in 0..w_out {
644                    let src_w: f64 = if align_corners {
645                        align_corners_coord(ow, w_in, w_out)
646                    } else {
647                        half_pixel_coord(ow, w_in, w_out)
648                    };
649
650                    let w_floor = src_w.floor() as isize;
651                    let frac_w = src_w - w_floor as f64;
652
653                    let ww: [T; 4] = [
654                        T::from(cubic_weight(frac_w + 1.0)).unwrap(),
655                        T::from(cubic_weight(frac_w)).unwrap(),
656                        T::from(cubic_weight(frac_w - 1.0)).unwrap(),
657                        T::from(cubic_weight(frac_w - 2.0)).unwrap(),
658                    ];
659
660                    let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
661                    let g = go[out_idx];
662                    let base = (b * channels + c) * h_in;
663
664                    for dy in 0..4isize {
665                        let iy = clamp_coord(h_floor - 1 + dy, h_in - 1);
666                        for dx in 0..4isize {
667                            let ix = clamp_coord(w_floor - 1 + dx, w_in - 1);
668                            grad_input[(base + iy) * w_in + ix] +=
669                                g * wh[dy as usize] * ww[dx as usize];
670                        }
671                    }
672                }
673            }
674        }
675    }
676}
677
678// ===========================================================================
679// Upsample module
680// ===========================================================================
681
682/// Upsamples a `[B, C, H, W]` tensor to a target spatial size.
683///
684/// Supports nearest, bilinear, and bicubic interpolation. This is the
685/// module-based wrapper around [`interpolate`].
686///
687/// CL-317
688#[derive(Debug, Clone)]
689pub struct Upsample {
690    /// Target output spatial size `[H, W]`. If `None`, `scale_factor` is used.
691    pub size: Option<[usize; 2]>,
692    /// Scaling factor `[scale_h, scale_w]`. If `None`, `size` is used.
693    pub scale_factor: Option<[f64; 2]>,
694    /// Interpolation mode.
695    pub mode: InterpolateMode,
696    /// Whether to align corners (bilinear/bicubic only).
697    pub align_corners: bool,
698}
699
700impl Upsample {
701    /// Create a new `Upsample` with target `size`.
702    pub fn new(size: [usize; 2], mode: InterpolateMode) -> Self {
703        Self {
704            size: Some(size),
705            scale_factor: None,
706            mode,
707            align_corners: false,
708        }
709    }
710
711    /// Create a new `Upsample` with a `scale_factor`.
712    pub fn with_scale_factor(scale_factor: [f64; 2], mode: InterpolateMode) -> Self {
713        Self {
714            size: None,
715            scale_factor: Some(scale_factor),
716            mode,
717            align_corners: false,
718        }
719    }
720
721    /// Set `align_corners` (meaningful for bilinear/bicubic only).
722    pub fn align_corners(mut self, align_corners: bool) -> Self {
723        self.align_corners = align_corners;
724        self
725    }
726}
727
728impl<T: Float> Module<T> for Upsample {
729    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
730        interpolate(
731            input,
732            self.size,
733            self.scale_factor,
734            self.mode,
735            self.align_corners,
736        )
737    }
738
739    fn parameters(&self) -> Vec<&Parameter<T>> {
740        vec![]
741    }
742
743    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
744        vec![]
745    }
746
747    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
748        vec![]
749    }
750
751    fn train(&mut self) {}
752    fn eval(&mut self) {}
753
754    fn is_training(&self) -> bool {
755        false
756    }
757}
758
759// ===========================================================================
760// grid_sample
761// ===========================================================================
762
763/// Samples `input` at spatial locations specified by `grid`.
764///
765/// This implements the spatial transformer network sampling operation.
766///
767/// # Shapes
768///
769/// - `input`: `[B, C, H_in, W_in]`
770/// - `grid`: `[B, H_out, W_out, 2]` — normalized coordinates in `[-1, 1]`
771/// - **returns**: `[B, C, H_out, W_out]`
772///
773/// CL-317
774pub fn grid_sample<T: Float>(
775    input: &Tensor<T>,
776    grid: &Tensor<T>,
777    mode: GridSampleMode,
778    padding_mode: GridSamplePaddingMode,
779    align_corners: bool,
780) -> FerrotorchResult<Tensor<T>> {
781    let (batch, channels, h_in, w_in) = validate_4d(input, "grid_sample")?;
782
783    let grid_shape = grid.shape();
784    if grid_shape.len() != 4 || grid_shape[3] != 2 {
785        return Err(FerrotorchError::InvalidArgument {
786            message: format!(
787                "grid_sample: grid must be [B, H_out, W_out, 2], got {:?}",
788                grid_shape
789            ),
790        });
791    }
792    if grid_shape[0] != batch {
793        return Err(FerrotorchError::ShapeMismatch {
794            message: format!(
795                "grid_sample: batch mismatch between input ({batch}) and grid ({})",
796                grid_shape[0]
797            ),
798        });
799    }
800    let h_out = grid_shape[1];
801    let w_out = grid_shape[2];
802
803    let input_device = input.device();
804    let in_data = input.data_vec()?;
805    let grid_data = grid.data_vec()?;
806
807    let total = batch * channels * h_out * w_out;
808    let mut output = vec![T::from(0.0).unwrap(); total];
809
810    let one = T::from(1.0).unwrap();
811    let two = T::from(2.0).unwrap();
812    let zero = T::from(0.0).unwrap();
813
814    for b in 0..batch {
815        for oh in 0..h_out {
816            for ow in 0..w_out {
817                let grid_base = ((b * h_out + oh) * w_out + ow) * 2;
818                let gx = grid_data[grid_base]; // normalized x
819                let gy = grid_data[grid_base + 1]; // normalized y
820
821                // Denormalize grid coordinates from [-1, 1] to pixel space.
822                let (src_x, src_y) = if align_corners {
823                    let sx = (gx + one) * T::from(w_in - 1).unwrap() / two;
824                    let sy = (gy + one) * T::from(h_in - 1).unwrap() / two;
825                    (sx, sy)
826                } else {
827                    let sx = ((gx + one) * T::from(w_in).unwrap() - one) / two;
828                    let sy = ((gy + one) * T::from(h_in).unwrap() - one) / two;
829                    (sx, sy)
830                };
831
832                for c in 0..channels {
833                    let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
834                    let in_base = (b * channels + c) * h_in;
835
836                    match mode {
837                        GridSampleMode::Nearest => {
838                            let ix = src_x.to_f64().unwrap().round() as isize;
839                            let iy = src_y.to_f64().unwrap().round() as isize;
840                            let (ix, iy) = apply_padding_mode(ix, iy, w_in, h_in, padding_mode);
841
842                            if ix >= 0 && ix < w_in as isize && iy >= 0 && iy < h_in as isize {
843                                output[out_idx] =
844                                    in_data[(in_base + iy as usize) * w_in + ix as usize];
845                            }
846                            // else stays zero (for Zeros padding)
847                        }
848                        GridSampleMode::Bilinear => {
849                            let sx = src_x.to_f64().unwrap();
850                            let sy = src_y.to_f64().unwrap();
851                            let x0 = sx.floor() as isize;
852                            let y0 = sy.floor() as isize;
853                            let x1 = x0 + 1;
854                            let y1 = y0 + 1;
855                            let tx = T::from(sx - x0 as f64).unwrap();
856                            let ty = T::from(sy - y0 as f64).unwrap();
857
858                            let get_pixel = |iy: isize, ix: isize| -> T {
859                                let (ix, iy) = apply_padding_mode(ix, iy, w_in, h_in, padding_mode);
860                                if ix >= 0 && ix < w_in as isize && iy >= 0 && iy < h_in as isize {
861                                    in_data[(in_base + iy as usize) * w_in + ix as usize]
862                                } else {
863                                    zero
864                                }
865                            };
866
867                            let v00 = get_pixel(y0, x0);
868                            let v01 = get_pixel(y0, x1);
869                            let v10 = get_pixel(y1, x0);
870                            let v11 = get_pixel(y1, x1);
871
872                            output[out_idx] = v00 * (one - ty) * (one - tx)
873                                + v01 * (one - ty) * tx
874                                + v10 * ty * (one - tx)
875                                + v11 * ty * tx;
876                        }
877                    }
878                }
879            }
880        }
881    }
882
883    let out_shape = vec![batch, channels, h_out, w_out];
884    let storage = TensorStorage::cpu(output);
885
886    if is_grad_enabled() && (input.requires_grad() || grid.requires_grad()) {
887        Tensor::from_operation(
888            storage,
889            out_shape,
890            Arc::new(GridSampleBackward {
891                input: input.clone(),
892                grid: grid.clone(),
893                mode,
894                padding_mode,
895                align_corners,
896            }),
897        )?
898        .to(input_device)
899    } else {
900        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
901    }
902}
903
904/// Apply padding mode to grid coordinates.
905fn apply_padding_mode(
906    ix: isize,
907    iy: isize,
908    w: usize,
909    h: usize,
910    padding_mode: GridSamplePaddingMode,
911) -> (isize, isize) {
912    match padding_mode {
913        GridSamplePaddingMode::Zeros => (ix, iy),
914        GridSamplePaddingMode::Border => {
915            let cx = ix.max(0).min(w as isize - 1);
916            let cy = iy.max(0).min(h as isize - 1);
917            (cx, cy)
918        }
919        GridSamplePaddingMode::Reflection => {
920            let reflect = |v: isize, size: usize| -> isize {
921                if size <= 1 {
922                    return 0;
923                }
924                let max = size as isize - 1;
925                let mut v = v;
926                if v < 0 {
927                    v = -v;
928                }
929                // Fold via period 2*(size-1)
930                let period = 2 * max;
931                v %= period;
932                if v > max {
933                    v = period - v;
934                }
935                v
936            };
937            (reflect(ix, w), reflect(iy, h))
938        }
939    }
940}
941
942// ---------------------------------------------------------------------------
943// GridSample backward
944// ---------------------------------------------------------------------------
945
946#[derive(Debug)]
947struct GridSampleBackward<T: Float> {
948    input: Tensor<T>,
949    grid: Tensor<T>,
950    mode: GridSampleMode,
951    padding_mode: GridSamplePaddingMode,
952    align_corners: bool,
953}
954
955impl<T: Float> GradFn<T> for GridSampleBackward<T> {
956    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
957        let in_shape = self.input.shape();
958        let (batch, channels, h_in, w_in) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
959        let grid_shape = self.grid.shape();
960        let h_out = grid_shape[1];
961        let w_out = grid_shape[2];
962
963        let go_data = grad_output.data_vec()?;
964        let in_data = self.input.data_vec()?;
965        let grid_data = self.grid.data_vec()?;
966
967        let one = T::from(1.0).unwrap();
968        let two = T::from(2.0).unwrap();
969        let zero = T::from(0.0).unwrap();
970
971        let grad_input_needed = self.input.requires_grad();
972        let grad_grid_needed = self.grid.requires_grad();
973
974        let mut grad_input = if grad_input_needed {
975            vec![zero; batch * channels * h_in * w_in]
976        } else {
977            vec![]
978        };
979        let mut grad_grid = if grad_grid_needed {
980            vec![zero; batch * h_out * w_out * 2]
981        } else {
982            vec![]
983        };
984
985        for b in 0..batch {
986            for oh in 0..h_out {
987                for ow in 0..w_out {
988                    let grid_base = ((b * h_out + oh) * w_out + ow) * 2;
989                    let gx = grid_data[grid_base];
990                    let gy = grid_data[grid_base + 1];
991
992                    let (src_x, src_y) = if self.align_corners {
993                        let sx = (gx + one) * T::from(w_in - 1).unwrap() / two;
994                        let sy = (gy + one) * T::from(h_in - 1).unwrap() / two;
995                        (sx, sy)
996                    } else {
997                        let sx = ((gx + one) * T::from(w_in).unwrap() - one) / two;
998                        let sy = ((gy + one) * T::from(h_in).unwrap() - one) / two;
999                        (sx, sy)
1000                    };
1001
1002                    match self.mode {
1003                        GridSampleMode::Bilinear => {
1004                            let sx = src_x.to_f64().unwrap();
1005                            let sy = src_y.to_f64().unwrap();
1006                            let x0 = sx.floor() as isize;
1007                            let y0 = sy.floor() as isize;
1008                            let x1 = x0 + 1;
1009                            let y1 = y0 + 1;
1010                            let tx = T::from(sx - x0 as f64).unwrap();
1011                            let ty = T::from(sy - y0 as f64).unwrap();
1012
1013                            let get_clamped = |iy: isize, ix: isize| -> (isize, isize) {
1014                                apply_padding_mode(ix, iy, w_in, h_in, self.padding_mode)
1015                            };
1016
1017                            for c in 0..channels {
1018                                let out_idx = ((b * channels + c) * h_out + oh) * w_out + ow;
1019                                let g = go_data[out_idx];
1020                                let in_base = (b * channels + c) * h_in;
1021
1022                                // Gradient w.r.t. input
1023                                if grad_input_needed {
1024                                    let coords = [
1025                                        (y0, x0, (one - ty) * (one - tx)),
1026                                        (y0, x1, (one - ty) * tx),
1027                                        (y1, x0, ty * (one - tx)),
1028                                        (y1, x1, ty * tx),
1029                                    ];
1030                                    for (iy, ix, w) in coords {
1031                                        let (ix, iy) = get_clamped(iy, ix);
1032                                        if ix >= 0
1033                                            && ix < w_in as isize
1034                                            && iy >= 0
1035                                            && iy < h_in as isize
1036                                        {
1037                                            grad_input
1038                                                [(in_base + iy as usize) * w_in + ix as usize] +=
1039                                                g * w;
1040                                        }
1041                                    }
1042                                }
1043
1044                                // Gradient w.r.t. grid
1045                                if grad_grid_needed {
1046                                    let get_pixel = |iy: isize, ix: isize| -> T {
1047                                        let (ix, iy) = get_clamped(iy, ix);
1048                                        if ix >= 0
1049                                            && ix < w_in as isize
1050                                            && iy >= 0
1051                                            && iy < h_in as isize
1052                                        {
1053                                            in_data[(in_base + iy as usize) * w_in + ix as usize]
1054                                        } else {
1055                                            zero
1056                                        }
1057                                    };
1058
1059                                    let v00 = get_pixel(y0, x0);
1060                                    let v01 = get_pixel(y0, x1);
1061                                    let v10 = get_pixel(y1, x0);
1062                                    let v11 = get_pixel(y1, x1);
1063
1064                                    // dout/d(src_x) = (1-ty)*(v01-v00) + ty*(v11-v10)
1065                                    let dout_dsx = (one - ty) * (v01 - v00) + ty * (v11 - v10);
1066                                    // dout/d(src_y) = (1-tx)*(v10-v00) + tx*(v11-v01)
1067                                    let dout_dsy = (one - tx) * (v10 - v00) + tx * (v11 - v01);
1068
1069                                    // d(src_x)/d(gx)
1070                                    let dsx_dgx = if self.align_corners {
1071                                        T::from(w_in - 1).unwrap() / two
1072                                    } else {
1073                                        T::from(w_in).unwrap() / two
1074                                    };
1075                                    let dsy_dgy = if self.align_corners {
1076                                        T::from(h_in - 1).unwrap() / two
1077                                    } else {
1078                                        T::from(h_in).unwrap() / two
1079                                    };
1080
1081                                    grad_grid[grid_base] += g * dout_dsx * dsx_dgx;
1082                                    grad_grid[grid_base + 1] += g * dout_dsy * dsy_dgy;
1083                                }
1084                            }
1085                        }
1086                        GridSampleMode::Nearest => {
1087                            // Nearest has zero gradient w.r.t. grid coordinates.
1088                            // Only accumulate gradient for input.
1089                            if grad_input_needed {
1090                                let ix = src_x.to_f64().unwrap().round() as isize;
1091                                let iy = src_y.to_f64().unwrap().round() as isize;
1092                                let (ix, iy) =
1093                                    apply_padding_mode(ix, iy, w_in, h_in, self.padding_mode);
1094
1095                                if ix >= 0 && ix < w_in as isize && iy >= 0 && iy < h_in as isize {
1096                                    for c in 0..channels {
1097                                        let out_idx =
1098                                            ((b * channels + c) * h_out + oh) * w_out + ow;
1099                                        let in_base = (b * channels + c) * h_in;
1100                                        grad_input[(in_base + iy as usize) * w_in + ix as usize] +=
1101                                            go_data[out_idx];
1102                                    }
1103                                }
1104                            }
1105                        }
1106                    }
1107                }
1108            }
1109        }
1110
1111        let gi = if grad_input_needed {
1112            Some(Tensor::from_storage(
1113                TensorStorage::cpu(grad_input),
1114                self.input.shape().to_vec(),
1115                false,
1116            )?)
1117        } else {
1118            None
1119        };
1120
1121        let gg = if grad_grid_needed {
1122            Some(Tensor::from_storage(
1123                TensorStorage::cpu(grad_grid),
1124                self.grid.shape().to_vec(),
1125                false,
1126            )?)
1127        } else {
1128            None
1129        };
1130
1131        Ok(vec![gi, gg])
1132    }
1133
1134    fn inputs(&self) -> Vec<&Tensor<T>> {
1135        vec![&self.input, &self.grid]
1136    }
1137
1138    fn name(&self) -> &'static str {
1139        "GridSampleBackward"
1140    }
1141}
1142
1143// ===========================================================================
1144// affine_grid
1145// ===========================================================================
1146
1147/// Generate a 2D affine grid for use with [`grid_sample`].
1148///
1149/// # Shapes
1150///
1151/// - `theta`: `[B, 2, 3]` — 2D affine transformation matrices.
1152/// - `size`: `[B, C, H, W]` — the target output size.
1153/// - **returns**: `[B, H, W, 2]` — normalized grid coordinates.
1154///
1155/// CL-317
1156pub fn affine_grid<T: Float>(
1157    theta: &Tensor<T>,
1158    size: [usize; 4],
1159    align_corners: bool,
1160) -> FerrotorchResult<Tensor<T>> {
1161    let theta_shape = theta.shape();
1162    if theta_shape.len() != 3 || theta_shape[1] != 2 || theta_shape[2] != 3 {
1163        return Err(FerrotorchError::InvalidArgument {
1164            message: format!(
1165                "affine_grid: theta must be [B, 2, 3], got {:?}",
1166                theta_shape
1167            ),
1168        });
1169    }
1170    let batch = theta_shape[0];
1171    if size[0] != batch {
1172        return Err(FerrotorchError::ShapeMismatch {
1173            message: format!(
1174                "affine_grid: batch mismatch: theta batch {batch}, size batch {}",
1175                size[0]
1176            ),
1177        });
1178    }
1179
1180    let h = size[2];
1181    let w = size[3];
1182    let one = T::from(1.0).unwrap();
1183    let two = T::from(2.0).unwrap();
1184
1185    let theta_data = theta.data_vec()?;
1186    let theta_device = theta.device();
1187    let total = batch * h * w * 2;
1188    let mut grid = vec![T::from(0.0).unwrap(); total];
1189
1190    for b in 0..batch {
1191        let t_base = b * 6;
1192        let t00 = theta_data[t_base];
1193        let t01 = theta_data[t_base + 1];
1194        let t02 = theta_data[t_base + 2];
1195        let t10 = theta_data[t_base + 3];
1196        let t11 = theta_data[t_base + 4];
1197        let t12 = theta_data[t_base + 5];
1198
1199        for iy in 0..h {
1200            let y_norm = if align_corners {
1201                if h <= 1 {
1202                    T::from(0.0).unwrap()
1203                } else {
1204                    two * T::from(iy).unwrap() / T::from(h - 1).unwrap() - one
1205                }
1206            } else {
1207                (two * T::from(iy).unwrap() + one) / T::from(h).unwrap() - one
1208            };
1209
1210            for ix in 0..w {
1211                let x_norm = if align_corners {
1212                    if w <= 1 {
1213                        T::from(0.0).unwrap()
1214                    } else {
1215                        two * T::from(ix).unwrap() / T::from(w - 1).unwrap() - one
1216                    }
1217                } else {
1218                    (two * T::from(ix).unwrap() + one) / T::from(w).unwrap() - one
1219                };
1220
1221                let out_base = ((b * h + iy) * w + ix) * 2;
1222                grid[out_base] = t00 * x_norm + t01 * y_norm + t02;
1223                grid[out_base + 1] = t10 * x_norm + t11 * y_norm + t12;
1224            }
1225        }
1226    }
1227
1228    let out_shape = vec![batch, h, w, 2];
1229    let storage = TensorStorage::cpu(grid);
1230
1231    if is_grad_enabled() && theta.requires_grad() {
1232        Tensor::from_operation(
1233            storage,
1234            out_shape,
1235            Arc::new(AffineGridBackward {
1236                theta: theta.clone(),
1237                size,
1238                align_corners,
1239            }),
1240        )?
1241        .to(theta_device)
1242    } else {
1243        Tensor::from_storage(storage, out_shape, false)?.to(theta_device)
1244    }
1245}
1246
1247#[derive(Debug)]
1248struct AffineGridBackward<T: Float> {
1249    theta: Tensor<T>,
1250    size: [usize; 4],
1251    align_corners: bool,
1252}
1253
1254impl<T: Float> GradFn<T> for AffineGridBackward<T> {
1255    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1256        if !self.theta.requires_grad() {
1257            return Ok(vec![None]);
1258        }
1259
1260        let batch = self.size[0];
1261        let h = self.size[2];
1262        let w = self.size[3];
1263        let one = T::from(1.0).unwrap();
1264        let two = T::from(2.0).unwrap();
1265        let zero = T::from(0.0).unwrap();
1266
1267        let go_data = grad_output.data_vec()?;
1268        let mut grad_theta = vec![zero; batch * 6];
1269
1270        for b in 0..batch {
1271            for iy in 0..h {
1272                let y_norm = if self.align_corners {
1273                    if h <= 1 {
1274                        zero
1275                    } else {
1276                        two * T::from(iy).unwrap() / T::from(h - 1).unwrap() - one
1277                    }
1278                } else {
1279                    (two * T::from(iy).unwrap() + one) / T::from(h).unwrap() - one
1280                };
1281
1282                for ix in 0..w {
1283                    let x_norm = if self.align_corners {
1284                        if w <= 1 {
1285                            zero
1286                        } else {
1287                            two * T::from(ix).unwrap() / T::from(w - 1).unwrap() - one
1288                        }
1289                    } else {
1290                        (two * T::from(ix).unwrap() + one) / T::from(w).unwrap() - one
1291                    };
1292
1293                    let go_base = ((b * h + iy) * w + ix) * 2;
1294                    let gx = go_data[go_base];
1295                    let gy = go_data[go_base + 1];
1296
1297                    let t_base = b * 6;
1298                    // d(grid_x) / d(t00) = x_norm, d(grid_x) / d(t01) = y_norm, d(grid_x) / d(t02) = 1
1299                    grad_theta[t_base] += gx * x_norm;
1300                    grad_theta[t_base + 1] += gx * y_norm;
1301                    grad_theta[t_base + 2] += gx;
1302                    // d(grid_y) / d(t10) = x_norm, d(grid_y) / d(t11) = y_norm, d(grid_y) / d(t12) = 1
1303                    grad_theta[t_base + 3] += gy * x_norm;
1304                    grad_theta[t_base + 4] += gy * y_norm;
1305                    grad_theta[t_base + 5] += gy;
1306                }
1307            }
1308        }
1309
1310        let grad_tensor = Tensor::from_storage(
1311            TensorStorage::cpu(grad_theta),
1312            self.theta.shape().to_vec(),
1313            false,
1314        )?;
1315        Ok(vec![Some(grad_tensor)])
1316    }
1317
1318    fn inputs(&self) -> Vec<&Tensor<T>> {
1319        vec![&self.theta]
1320    }
1321
1322    fn name(&self) -> &'static str {
1323        "AffineGridBackward"
1324    }
1325}
1326
1327// ===========================================================================
1328// PixelShuffle / PixelUnshuffle
1329// ===========================================================================
1330
1331/// Rearranges `[B, C*r*r, H, W]` to `[B, C, H*r, W*r]` (sub-pixel convolution).
1332///
1333/// Used in super-resolution networks to upsample feature maps without
1334/// transposed convolutions.
1335///
1336/// CL-317
1337#[derive(Debug, Clone, Copy)]
1338pub struct PixelShuffle {
1339    /// Upscale factor.
1340    pub upscale_factor: usize,
1341}
1342
1343impl PixelShuffle {
1344    pub fn new(upscale_factor: usize) -> Self {
1345        Self { upscale_factor }
1346    }
1347}
1348
1349impl<T: Float> Module<T> for PixelShuffle {
1350    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1351        pixel_shuffle(input, self.upscale_factor)
1352    }
1353
1354    fn parameters(&self) -> Vec<&Parameter<T>> {
1355        vec![]
1356    }
1357    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1358        vec![]
1359    }
1360    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1361        vec![]
1362    }
1363    fn train(&mut self) {}
1364    fn eval(&mut self) {}
1365    fn is_training(&self) -> bool {
1366        false
1367    }
1368}
1369
1370/// Rearranges `[B, C, H*r, W*r]` to `[B, C*r*r, H, W]` (inverse sub-pixel convolution).
1371///
1372/// CL-317
1373#[derive(Debug, Clone, Copy)]
1374pub struct PixelUnshuffle {
1375    /// Downscale factor.
1376    pub downscale_factor: usize,
1377}
1378
1379impl PixelUnshuffle {
1380    pub fn new(downscale_factor: usize) -> Self {
1381        Self { downscale_factor }
1382    }
1383}
1384
1385impl<T: Float> Module<T> for PixelUnshuffle {
1386    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1387        pixel_unshuffle(input, self.downscale_factor)
1388    }
1389
1390    fn parameters(&self) -> Vec<&Parameter<T>> {
1391        vec![]
1392    }
1393    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1394        vec![]
1395    }
1396    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1397        vec![]
1398    }
1399    fn train(&mut self) {}
1400    fn eval(&mut self) {}
1401    fn is_training(&self) -> bool {
1402        false
1403    }
1404}
1405
1406/// Functional pixel shuffle: `[B, C*r*r, H, W]` -> `[B, C, H*r, W*r]`.
1407///
1408/// CL-317
1409pub fn pixel_shuffle<T: Float>(
1410    input: &Tensor<T>,
1411    upscale_factor: usize,
1412) -> FerrotorchResult<Tensor<T>> {
1413    let (batch, channels_in, h, w) = validate_4d(input, "pixel_shuffle")?;
1414    let r = upscale_factor;
1415
1416    if r == 0 {
1417        return Err(FerrotorchError::InvalidArgument {
1418            message: "pixel_shuffle: upscale_factor must be > 0".into(),
1419        });
1420    }
1421    if channels_in % (r * r) != 0 {
1422        return Err(FerrotorchError::InvalidArgument {
1423            message: format!(
1424                "pixel_shuffle: channels ({channels_in}) must be divisible by r^2 ({})",
1425                r * r
1426            ),
1427        });
1428    }
1429
1430    let c_out = channels_in / (r * r);
1431    let h_out = h * r;
1432    let w_out = w * r;
1433
1434    let input_device = input.device();
1435    let data = input.data_vec()?;
1436
1437    let total = batch * c_out * h_out * w_out;
1438    let mut output = vec![T::from(0.0).unwrap(); total];
1439
1440    // Layout: input channels are organized as [c, r_h, r_w] sub-groups.
1441    for b in 0..batch {
1442        for c in 0..c_out {
1443            for ih in 0..h {
1444                for iw in 0..w {
1445                    for rh in 0..r {
1446                        for rw in 0..r {
1447                            let in_c = c * r * r + rh * r + rw;
1448                            let in_idx = ((b * channels_in + in_c) * h + ih) * w + iw;
1449
1450                            let oh = ih * r + rh;
1451                            let ow_pos = iw * r + rw;
1452                            let out_idx = ((b * c_out + c) * h_out + oh) * w_out + ow_pos;
1453
1454                            output[out_idx] = data[in_idx];
1455                        }
1456                    }
1457                }
1458            }
1459        }
1460    }
1461
1462    let out_shape = vec![batch, c_out, h_out, w_out];
1463    let storage = TensorStorage::cpu(output);
1464
1465    if is_grad_enabled() && input.requires_grad() {
1466        Tensor::from_operation(
1467            storage,
1468            out_shape,
1469            Arc::new(PixelShuffleBackward {
1470                input: input.clone(),
1471                upscale_factor: r,
1472            }),
1473        )?
1474        .to(input_device)
1475    } else {
1476        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1477    }
1478}
1479
1480/// Functional pixel unshuffle: `[B, C, H*r, W*r]` -> `[B, C*r*r, H, W]`.
1481///
1482/// CL-317
1483pub fn pixel_unshuffle<T: Float>(
1484    input: &Tensor<T>,
1485    downscale_factor: usize,
1486) -> FerrotorchResult<Tensor<T>> {
1487    let (batch, channels, h_in, w_in) = validate_4d(input, "pixel_unshuffle")?;
1488    let r = downscale_factor;
1489
1490    if r == 0 {
1491        return Err(FerrotorchError::InvalidArgument {
1492            message: "pixel_unshuffle: downscale_factor must be > 0".into(),
1493        });
1494    }
1495    if h_in % r != 0 || w_in % r != 0 {
1496        return Err(FerrotorchError::InvalidArgument {
1497            message: format!(
1498                "pixel_unshuffle: spatial dims ({h_in}, {w_in}) must be divisible by r={r}"
1499            ),
1500        });
1501    }
1502
1503    let h_out = h_in / r;
1504    let w_out = w_in / r;
1505    let c_out = channels * r * r;
1506
1507    let input_device = input.device();
1508    let data = input.data_vec()?;
1509
1510    let total = batch * c_out * h_out * w_out;
1511    let mut output = vec![T::from(0.0).unwrap(); total];
1512
1513    for b in 0..batch {
1514        for c in 0..channels {
1515            for oh in 0..h_out {
1516                for ow in 0..w_out {
1517                    for rh in 0..r {
1518                        for rw in 0..r {
1519                            let in_h = oh * r + rh;
1520                            let in_w = ow * r + rw;
1521                            let in_idx = ((b * channels + c) * h_in + in_h) * w_in + in_w;
1522
1523                            let out_c = c * r * r + rh * r + rw;
1524                            let out_idx = ((b * c_out + out_c) * h_out + oh) * w_out + ow;
1525
1526                            output[out_idx] = data[in_idx];
1527                        }
1528                    }
1529                }
1530            }
1531        }
1532    }
1533
1534    let out_shape = vec![batch, c_out, h_out, w_out];
1535    let storage = TensorStorage::cpu(output);
1536
1537    if is_grad_enabled() && input.requires_grad() {
1538        Tensor::from_operation(
1539            storage,
1540            out_shape,
1541            Arc::new(PixelUnshuffleBackward {
1542                input: input.clone(),
1543                downscale_factor: r,
1544            }),
1545        )?
1546        .to(input_device)
1547    } else {
1548        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1549    }
1550}
1551
1552#[derive(Debug)]
1553struct PixelShuffleBackward<T: Float> {
1554    input: Tensor<T>,
1555    upscale_factor: usize,
1556}
1557
1558impl<T: Float> GradFn<T> for PixelShuffleBackward<T> {
1559    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1560        if !self.input.requires_grad() {
1561            return Ok(vec![None]);
1562        }
1563        // Backward of pixel_shuffle is pixel_unshuffle.
1564        let grad_input = pixel_unshuffle(grad_output, self.upscale_factor)?;
1565        Ok(vec![Some(grad_input)])
1566    }
1567
1568    fn inputs(&self) -> Vec<&Tensor<T>> {
1569        vec![&self.input]
1570    }
1571
1572    fn name(&self) -> &'static str {
1573        "PixelShuffleBackward"
1574    }
1575}
1576
1577#[derive(Debug)]
1578struct PixelUnshuffleBackward<T: Float> {
1579    input: Tensor<T>,
1580    downscale_factor: usize,
1581}
1582
1583impl<T: Float> GradFn<T> for PixelUnshuffleBackward<T> {
1584    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1585        if !self.input.requires_grad() {
1586            return Ok(vec![None]);
1587        }
1588        // Backward of pixel_unshuffle is pixel_shuffle.
1589        let grad_input = pixel_shuffle(grad_output, self.downscale_factor)?;
1590        Ok(vec![Some(grad_input)])
1591    }
1592
1593    fn inputs(&self) -> Vec<&Tensor<T>> {
1594        vec![&self.input]
1595    }
1596
1597    fn name(&self) -> &'static str {
1598        "PixelUnshuffleBackward"
1599    }
1600}
1601
1602// ===========================================================================
1603// Unfold / Fold
1604// ===========================================================================
1605
1606/// Extracts sliding-window patches from a `[B, C, H, W]` tensor and
1607/// reshapes them into columns: output `[B, C * kH * kW, L]` where
1608/// `L = out_h * out_w`.
1609///
1610/// This is the im2col operation used in efficient convolution implementations.
1611///
1612/// CL-317
1613#[derive(Debug, Clone, Copy)]
1614pub struct Unfold {
1615    pub kernel_size: [usize; 2],
1616    pub dilation: [usize; 2],
1617    pub padding: [usize; 2],
1618    pub stride: [usize; 2],
1619}
1620
1621impl Unfold {
1622    pub fn new(
1623        kernel_size: [usize; 2],
1624        dilation: [usize; 2],
1625        padding: [usize; 2],
1626        stride: [usize; 2],
1627    ) -> Self {
1628        Self {
1629            kernel_size,
1630            dilation,
1631            padding,
1632            stride,
1633        }
1634    }
1635}
1636
1637impl<T: Float> Module<T> for Unfold {
1638    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1639        unfold(
1640            input,
1641            self.kernel_size,
1642            self.dilation,
1643            self.padding,
1644            self.stride,
1645        )
1646    }
1647
1648    fn parameters(&self) -> Vec<&Parameter<T>> {
1649        vec![]
1650    }
1651    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1652        vec![]
1653    }
1654    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1655        vec![]
1656    }
1657    fn train(&mut self) {}
1658    fn eval(&mut self) {}
1659    fn is_training(&self) -> bool {
1660        false
1661    }
1662}
1663
1664/// Reconstructs a `[B, C, H, W]` tensor from sliding-window columns
1665/// `[B, C * kH * kW, L]`, the inverse of [`Unfold`].
1666///
1667/// `output_size` specifies the original `[H, W]` spatial dimensions.
1668///
1669/// CL-317
1670#[derive(Debug, Clone, Copy)]
1671pub struct Fold {
1672    pub output_size: [usize; 2],
1673    pub kernel_size: [usize; 2],
1674    pub dilation: [usize; 2],
1675    pub padding: [usize; 2],
1676    pub stride: [usize; 2],
1677}
1678
1679impl Fold {
1680    pub fn new(
1681        output_size: [usize; 2],
1682        kernel_size: [usize; 2],
1683        dilation: [usize; 2],
1684        padding: [usize; 2],
1685        stride: [usize; 2],
1686    ) -> Self {
1687        Self {
1688            output_size,
1689            kernel_size,
1690            dilation,
1691            padding,
1692            stride,
1693        }
1694    }
1695}
1696
1697impl<T: Float> Module<T> for Fold {
1698    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
1699        fold(
1700            input,
1701            self.output_size,
1702            self.kernel_size,
1703            self.dilation,
1704            self.padding,
1705            self.stride,
1706        )
1707    }
1708
1709    fn parameters(&self) -> Vec<&Parameter<T>> {
1710        vec![]
1711    }
1712    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
1713        vec![]
1714    }
1715    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
1716        vec![]
1717    }
1718    fn train(&mut self) {}
1719    fn eval(&mut self) {}
1720    fn is_training(&self) -> bool {
1721        false
1722    }
1723}
1724
/// Number of sliding-window positions along one spatial dimension.
///
/// Evaluates
/// `(input_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1`
/// with integer division.
///
/// Returns `0` when no window fits: the dilated kernel is longer than the
/// padded input, or `kernel_size` / `stride` is zero. The plain `usize`
/// formula would underflow in the first case — a panic in debug builds and
/// an enormous bogus size in release builds.
#[inline]
fn unfold_output_size(
    input_size: usize,
    kernel_size: usize,
    dilation: usize,
    padding: usize,
    stride: usize,
) -> usize {
    if kernel_size == 0 || stride == 0 {
        return 0;
    }

    let padded = input_size.saturating_add(padding.saturating_mul(2));
    // Distance covered by one window: first tap to last tap, inclusive.
    let extent = dilation.saturating_mul(kernel_size - 1).saturating_add(1);

    match padded.checked_sub(extent) {
        Some(slack) => slack / stride + 1,
        None => 0,
    }
}
1736
1737/// Functional unfold (im2col): `[B, C, H, W]` -> `[B, C*kH*kW, L]`.
1738///
1739/// CL-317
1740pub fn unfold<T: Float>(
1741    input: &Tensor<T>,
1742    kernel_size: [usize; 2],
1743    dilation: [usize; 2],
1744    padding: [usize; 2],
1745    stride: [usize; 2],
1746) -> FerrotorchResult<Tensor<T>> {
1747    let (batch, channels, h, w) = validate_4d(input, "unfold")?;
1748
1749    if kernel_size[0] == 0
1750        || kernel_size[1] == 0
1751        || stride[0] == 0
1752        || stride[1] == 0
1753        || dilation[0] == 0
1754        || dilation[1] == 0
1755    {
1756        return Err(FerrotorchError::InvalidArgument {
1757            message: "unfold: kernel_size, stride, dilation must all be > 0".into(),
1758        });
1759    }
1760
1761    let out_h = unfold_output_size(h, kernel_size[0], dilation[0], padding[0], stride[0]);
1762    let out_w = unfold_output_size(w, kernel_size[1], dilation[1], padding[1], stride[1]);
1763    let l = out_h * out_w;
1764    let k = channels * kernel_size[0] * kernel_size[1];
1765
1766    let input_device = input.device();
1767    let data = input.data_vec()?;
1768
1769    let total = batch * k * l;
1770    let mut output = vec![T::from(0.0).unwrap(); total];
1771
1772    for b in 0..batch {
1773        for c in 0..channels {
1774            for kh in 0..kernel_size[0] {
1775                for kw in 0..kernel_size[1] {
1776                    let k_idx = (c * kernel_size[0] + kh) * kernel_size[1] + kw;
1777                    for oh in 0..out_h {
1778                        for ow in 0..out_w {
1779                            let ih = oh * stride[0] + kh * dilation[0];
1780                            let iw = ow * stride[1] + kw * dilation[1];
1781                            let ih = ih as isize - padding[0] as isize;
1782                            let iw = iw as isize - padding[1] as isize;
1783
1784                            let l_idx = oh * out_w + ow;
1785                            let out_idx = (b * k + k_idx) * l + l_idx;
1786
1787                            if ih >= 0 && ih < h as isize && iw >= 0 && iw < w as isize {
1788                                let in_idx =
1789                                    ((b * channels + c) * h + ih as usize) * w + iw as usize;
1790                                output[out_idx] = data[in_idx];
1791                            }
1792                            // else stays zero (padded)
1793                        }
1794                    }
1795                }
1796            }
1797        }
1798    }
1799
1800    let out_shape = vec![batch, k, l];
1801    let storage = TensorStorage::cpu(output);
1802
1803    if is_grad_enabled() && input.requires_grad() {
1804        Tensor::from_operation(
1805            storage,
1806            out_shape,
1807            Arc::new(UnfoldBackward {
1808                input: input.clone(),
1809                kernel_size,
1810                dilation,
1811                padding,
1812                stride,
1813            }),
1814        )?
1815        .to(input_device)
1816    } else {
1817        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1818    }
1819}
1820
1821/// Functional fold (col2im): `[B, C*kH*kW, L]` -> `[B, C, H, W]`.
1822///
1823/// CL-317
1824pub fn fold<T: Float>(
1825    input: &Tensor<T>,
1826    output_size: [usize; 2],
1827    kernel_size: [usize; 2],
1828    dilation: [usize; 2],
1829    padding: [usize; 2],
1830    stride: [usize; 2],
1831) -> FerrotorchResult<Tensor<T>> {
1832    let shape = input.shape();
1833    if shape.len() != 3 {
1834        return Err(FerrotorchError::InvalidArgument {
1835            message: format!(
1836                "fold expects 3D input [B, C*kH*kW, L], got shape {:?}",
1837                shape
1838            ),
1839        });
1840    }
1841
1842    if kernel_size[0] == 0
1843        || kernel_size[1] == 0
1844        || stride[0] == 0
1845        || stride[1] == 0
1846        || dilation[0] == 0
1847        || dilation[1] == 0
1848    {
1849        return Err(FerrotorchError::InvalidArgument {
1850            message: "fold: kernel_size, stride, dilation must all be > 0".into(),
1851        });
1852    }
1853
1854    let batch = shape[0];
1855    let k = shape[1]; // C * kH * kW
1856    let l = shape[2]; // out_h * out_w
1857
1858    let [h_out, w_out] = output_size;
1859    let k_area = kernel_size[0] * kernel_size[1];
1860
1861    if k % k_area != 0 {
1862        return Err(FerrotorchError::InvalidArgument {
1863            message: format!("fold: dim 1 ({k}) must be divisible by kH*kW ({})", k_area),
1864        });
1865    }
1866    let channels = k / k_area;
1867
1868    let expected_out_h =
1869        unfold_output_size(h_out, kernel_size[0], dilation[0], padding[0], stride[0]);
1870    let expected_out_w =
1871        unfold_output_size(w_out, kernel_size[1], dilation[1], padding[1], stride[1]);
1872    let expected_l = expected_out_h * expected_out_w;
1873
1874    if l != expected_l {
1875        return Err(FerrotorchError::InvalidArgument {
1876            message: format!(
1877                "fold: L={l} does not match expected {expected_l} for output_size ({h_out}, {w_out})"
1878            ),
1879        });
1880    }
1881
1882    let input_device = input.device();
1883    let data = input.data_vec()?;
1884
1885    let total = batch * channels * h_out * w_out;
1886    let mut output = vec![T::from(0.0).unwrap(); total];
1887
1888    for b in 0..batch {
1889        for c in 0..channels {
1890            for kh in 0..kernel_size[0] {
1891                for kw in 0..kernel_size[1] {
1892                    let k_idx = (c * kernel_size[0] + kh) * kernel_size[1] + kw;
1893                    for oh in 0..expected_out_h {
1894                        for ow in 0..expected_out_w {
1895                            let ih = oh * stride[0] + kh * dilation[0];
1896                            let iw = ow * stride[1] + kw * dilation[1];
1897                            let ih = ih as isize - padding[0] as isize;
1898                            let iw = iw as isize - padding[1] as isize;
1899
1900                            if ih >= 0 && ih < h_out as isize && iw >= 0 && iw < w_out as isize {
1901                                let l_idx = oh * expected_out_w + ow;
1902                                let in_idx = (b * k + k_idx) * l + l_idx;
1903                                let out_idx = ((b * channels + c) * h_out + ih as usize) * w_out
1904                                    + iw as usize;
1905                                output[out_idx] += data[in_idx];
1906                            }
1907                        }
1908                    }
1909                }
1910            }
1911        }
1912    }
1913
1914    let out_shape = vec![batch, channels, h_out, w_out];
1915    let storage = TensorStorage::cpu(output);
1916
1917    if is_grad_enabled() && input.requires_grad() {
1918        Tensor::from_operation(
1919            storage,
1920            out_shape,
1921            Arc::new(FoldBackward {
1922                input: input.clone(),
1923                kernel_size,
1924                dilation,
1925                padding,
1926                stride,
1927            }),
1928        )?
1929        .to(input_device)
1930    } else {
1931        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
1932    }
1933}
1934
/// Autograd node attached to the output of [`unfold`].
///
/// Stores the forward input and the unfold hyper-parameters so the backward
/// pass can fold the incoming gradient back to the input's spatial size.
#[derive(Debug)]
struct UnfoldBackward<T: Float> {
    /// Forward input of `unfold`; its `[H, W]` dims and `requires_grad` flag
    /// are read during backward.
    input: Tensor<T>,
    /// Sliding-window size `[kH, kW]` used in the forward pass.
    kernel_size: [usize; 2],
    /// Dilation `[H, W]` used in the forward pass.
    dilation: [usize; 2],
    /// Padding `[H, W]` used in the forward pass.
    padding: [usize; 2],
    /// Stride `[H, W]` used in the forward pass.
    stride: [usize; 2],
}
1943
1944impl<T: Float> GradFn<T> for UnfoldBackward<T> {
1945    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1946        if !self.input.requires_grad() {
1947            return Ok(vec![None]);
1948        }
1949        // Backward of unfold is fold.
1950        let in_shape = self.input.shape();
1951        let h = in_shape[2];
1952        let w = in_shape[3];
1953        let grad_input = fold(
1954            grad_output,
1955            [h, w],
1956            self.kernel_size,
1957            self.dilation,
1958            self.padding,
1959            self.stride,
1960        )?;
1961        Ok(vec![Some(grad_input)])
1962    }
1963
1964    fn inputs(&self) -> Vec<&Tensor<T>> {
1965        vec![&self.input]
1966    }
1967
1968    fn name(&self) -> &'static str {
1969        "UnfoldBackward"
1970    }
1971}
1972
/// Autograd node attached to the output of [`fold`].
///
/// Stores the forward input and the fold hyper-parameters so the backward
/// pass can unfold the incoming gradient with the same window settings.
#[derive(Debug)]
struct FoldBackward<T: Float> {
    /// Forward input of `fold`; its `requires_grad` flag is read during
    /// backward.
    input: Tensor<T>,
    /// Sliding-window size `[kH, kW]` used in the forward pass.
    kernel_size: [usize; 2],
    /// Dilation `[H, W]` used in the forward pass.
    dilation: [usize; 2],
    /// Padding `[H, W]` used in the forward pass.
    padding: [usize; 2],
    /// Stride `[H, W]` used in the forward pass.
    stride: [usize; 2],
}
1981
1982impl<T: Float> GradFn<T> for FoldBackward<T> {
1983    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
1984        if !self.input.requires_grad() {
1985            return Ok(vec![None]);
1986        }
1987        // Backward of fold is unfold.
1988        let grad_input = unfold(
1989            grad_output,
1990            self.kernel_size,
1991            self.dilation,
1992            self.padding,
1993            self.stride,
1994        )?;
1995        Ok(vec![Some(grad_input)])
1996    }
1997
1998    fn inputs(&self) -> Vec<&Tensor<T>> {
1999        vec![&self.input]
2000    }
2001
2002    fn name(&self) -> &'static str {
2003        "FoldBackward"
2004    }
2005}
2006
2007// ===========================================================================
2008// Tests
2009// ===========================================================================
2010
2011#[cfg(test)]
2012mod tests {
2013    use super::*;
2014
2015    /// Create a leaf 4D tensor from flat data.
2016    fn leaf_4d(data: &[f32], shape: [usize; 4], requires_grad: bool) -> Tensor<f32> {
2017        Tensor::from_storage(
2018            TensorStorage::cpu(data.to_vec()),
2019            shape.to_vec(),
2020            requires_grad,
2021        )
2022        .unwrap()
2023    }
2024
2025    /// Create a leaf tensor with any shape.
2026    fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
2027        Tensor::from_storage(
2028            TensorStorage::cpu(data.to_vec()),
2029            shape.to_vec(),
2030            requires_grad,
2031        )
2032        .unwrap()
2033    }
2034
2035    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
2036        assert_eq!(
2037            actual.len(),
2038            expected.len(),
2039            "length mismatch: {} vs {}",
2040            actual.len(),
2041            expected.len()
2042        );
2043        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
2044            assert!(
2045                (a - e).abs() < tol,
2046                "index {i}: actual={a} expected={e} diff={}",
2047                (a - e).abs(),
2048            );
2049        }
2050    }
2051
2052    // -----------------------------------------------------------------------
2053    // Interpolation: Nearest
2054    // -----------------------------------------------------------------------
2055
2056    #[test]
2057    fn test_interpolate_nearest_upsample_2x() {
2058        // [1, 1, 2, 2] -> [1, 1, 4, 4]
2059        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
2060        let input = leaf_4d(&data, [1, 1, 2, 2], false);
2061        let out = interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).unwrap();
2062        assert_eq!(out.shape(), &[1, 1, 4, 4]);
2063
2064        let d = out.data().unwrap();
2065        // Each 2x2 block should repeat the source pixel.
2066        #[rustfmt::skip]
2067        let expected: Vec<f32> = vec![
2068            1.0, 1.0, 2.0, 2.0,
2069            1.0, 1.0, 2.0, 2.0,
2070            3.0, 3.0, 4.0, 4.0,
2071            3.0, 3.0, 4.0, 4.0,
2072        ];
2073        assert_close(d, &expected, 1e-6);
2074    }
2075
2076    #[test]
2077    fn test_interpolate_nearest_downsample() {
2078        // [1, 1, 4, 4] -> [1, 1, 2, 2]
2079        #[rustfmt::skip]
2080        let data: Vec<f32> = vec![
2081            1.0, 2.0, 3.0, 4.0,
2082            5.0, 6.0, 7.0, 8.0,
2083            9.0, 10.0, 11.0, 12.0,
2084            13.0, 14.0, 15.0, 16.0,
2085        ];
2086        let input = leaf_4d(&data, [1, 1, 4, 4], false);
2087        let out = interpolate(&input, Some([2, 2]), None, InterpolateMode::Nearest, false).unwrap();
2088        assert_eq!(out.shape(), &[1, 1, 2, 2]);
2089        let d = out.data().unwrap();
2090        // floor(0 * 2) = 0, floor(1 * 2) = 2; so we pick (0,0)=1, (0,2)=3, (2,0)=9, (2,2)=11
2091        assert_close(d, &[1.0, 3.0, 9.0, 11.0], 1e-6);
2092    }
2093
2094    #[test]
2095    fn test_interpolate_nearest_scale_factor() {
2096        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
2097        let input = leaf_4d(&data, [1, 1, 2, 2], false);
2098        let out = interpolate(
2099            &input,
2100            None,
2101            Some([2.0, 2.0]),
2102            InterpolateMode::Nearest,
2103            false,
2104        )
2105        .unwrap();
2106        assert_eq!(out.shape(), &[1, 1, 4, 4]);
2107    }
2108
2109    // -----------------------------------------------------------------------
2110    // Interpolation: Bilinear
2111    // -----------------------------------------------------------------------
2112
2113    #[test]
2114    fn test_interpolate_bilinear_upsample() {
2115        // [1, 1, 2, 2] with align_corners=true -> [1, 1, 3, 3]
2116        let data: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0];
2117        let input = leaf_4d(&data, [1, 1, 2, 2], false);
2118        let out = interpolate(&input, Some([3, 3]), None, InterpolateMode::Bilinear, true).unwrap();
2119        assert_eq!(out.shape(), &[1, 1, 3, 3]);
2120
2121        let d = out.data().unwrap();
2122        // Corners should match exactly.
2123        assert!((d[0] - 0.0).abs() < 1e-5); // top-left
2124        assert!((d[2] - 1.0).abs() < 1e-5); // top-right
2125        assert!((d[6] - 2.0).abs() < 1e-5); // bottom-left
2126        assert!((d[8] - 3.0).abs() < 1e-5); // bottom-right
2127        // Center should be average of all corners.
2128        assert!((d[4] - 1.5).abs() < 1e-5);
2129    }
2130
2131    #[test]
2132    fn test_interpolate_bilinear_identity() {
2133        // Same size should be approximately identity.
2134        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
2135        let input = leaf_4d(&data, [1, 1, 3, 3], false);
2136        let out = interpolate(&input, Some([3, 3]), None, InterpolateMode::Bilinear, true).unwrap();
2137        assert_eq!(out.shape(), &[1, 1, 3, 3]);
2138        assert_close(out.data().unwrap(), &data, 1e-5);
2139    }
2140
2141    // -----------------------------------------------------------------------
2142    // Interpolation: Bicubic
2143    // -----------------------------------------------------------------------
2144
2145    #[test]
2146    fn test_interpolate_bicubic_output_shape() {
2147        let data: Vec<f32> = vec![0.0; 64];
2148        let input = leaf_4d(&data, [1, 1, 8, 8], false);
2149        let out = interpolate(
2150            &input,
2151            Some([16, 16]),
2152            None,
2153            InterpolateMode::Bicubic,
2154            false,
2155        )
2156        .unwrap();
2157        assert_eq!(out.shape(), &[1, 1, 16, 16]);
2158    }
2159
2160    #[test]
2161    fn test_interpolate_bicubic_corners_align() {
2162        // With align_corners=true, the 4 corners of a 2x2 input should
2163        // map exactly to the 4 corners of the output.
2164        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
2165        let input = leaf_4d(&data, [1, 1, 2, 2], false);
2166        let out = interpolate(&input, Some([5, 5]), None, InterpolateMode::Bicubic, true).unwrap();
2167        assert_eq!(out.shape(), &[1, 1, 5, 5]);
2168        let d = out.data().unwrap();
2169        assert!((d[0] - 1.0).abs() < 1e-4); // top-left
2170        assert!((d[4] - 2.0).abs() < 1e-4); // top-right
2171        assert!((d[20] - 3.0).abs() < 1e-4); // bottom-left
2172        assert!((d[24] - 4.0).abs() < 1e-4); // bottom-right
2173    }
2174
2175    // -----------------------------------------------------------------------
2176    // Upsample module
2177    // -----------------------------------------------------------------------
2178
2179    #[test]
2180    fn test_upsample_module_nearest() {
2181        let up = Upsample::new([6, 6], InterpolateMode::Nearest);
2182        let input = leaf_4d(&[0.0; 9], [1, 1, 3, 3], false);
2183        let out: Tensor<f32> = Module::<f32>::forward(&up, &input).unwrap();
2184        assert_eq!(out.shape(), &[1, 1, 6, 6]);
2185    }
2186
2187    #[test]
2188    fn test_upsample_module_bilinear_scale() {
2189        let up = Upsample::with_scale_factor([2.0, 2.0], InterpolateMode::Bilinear);
2190        let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
2191        let out: Tensor<f32> = Module::<f32>::forward(&up, &input).unwrap();
2192        assert_eq!(out.shape(), &[1, 1, 4, 4]);
2193    }
2194
2195    #[test]
2196    fn test_upsample_no_parameters() {
2197        let up = Upsample::new([4, 4], InterpolateMode::Nearest);
2198        assert!(Module::<f32>::parameters(&up).is_empty());
2199    }
2200
2201    // -----------------------------------------------------------------------
2202    // Interpolate errors
2203    // -----------------------------------------------------------------------
2204
2205    #[test]
2206    fn test_interpolate_no_size_no_scale() {
2207        let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
2208        assert!(interpolate(&input, None, None, InterpolateMode::Nearest, false).is_err());
2209    }
2210
2211    #[test]
2212    fn test_interpolate_both_size_and_scale() {
2213        let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
2214        assert!(
2215            interpolate(
2216                &input,
2217                Some([4, 4]),
2218                Some([2.0, 2.0]),
2219                InterpolateMode::Nearest,
2220                false
2221            )
2222            .is_err()
2223        );
2224    }
2225
2226    #[test]
2227    fn test_interpolate_nearest_align_corners_rejected() {
2228        let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
2229        assert!(interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, true).is_err());
2230    }
2231
2232    #[test]
2233    fn test_interpolate_3d_rejected() {
2234        let input = leaf(&[0.0; 6], &[2, 3], false);
2235        assert!(interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).is_err());
2236    }
2237
2238    // -----------------------------------------------------------------------
2239    // Interpolate backward
2240    // -----------------------------------------------------------------------
2241
2242    #[test]
2243    fn test_interpolate_nearest_backward() {
2244        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
2245        let input = leaf_4d(&data, [1, 1, 2, 2], true);
2246        let out = interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).unwrap();
2247
2248        let out_data = out.data().unwrap().to_vec();
2249        let total: f32 = out_data.iter().sum();
2250        let loss = Tensor::from_operation(
2251            TensorStorage::cpu(vec![total]),
2252            vec![],
2253            Arc::new(TestSumBackward { input: out }),
2254        )
2255        .unwrap();
2256        loss.backward().unwrap();
2257
2258        let grad = input.grad().unwrap().unwrap();
2259        let g = grad.data().unwrap();
2260        // Each input pixel maps to 4 output pixels (2x2 block), so grad = 4.0 for each.
2261        for (i, &val) in g.iter().enumerate() {
2262            assert!(
2263                (val - 4.0).abs() < 1e-5,
2264                "grad[{i}]: expected 4.0, got {val}"
2265            );
2266        }
2267    }
2268
2269    #[test]
2270    fn test_interpolate_bilinear_backward() {
2271        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
2272        let input = leaf_4d(&data, [1, 1, 2, 2], true);
2273        let out =
2274            interpolate(&input, Some([4, 4]), None, InterpolateMode::Bilinear, false).unwrap();
2275
2276        let out_data = out.data().unwrap().to_vec();
2277        let total: f32 = out_data.iter().sum();
2278        let loss = Tensor::from_operation(
2279            TensorStorage::cpu(vec![total]),
2280            vec![],
2281            Arc::new(TestSumBackward { input: out }),
2282        )
2283        .unwrap();
2284        loss.backward().unwrap();
2285
2286        let grad = input.grad().unwrap().unwrap();
2287        let g = grad.data().unwrap();
2288        // Check gradient is non-zero and sums correctly.
2289        let grad_sum: f32 = g.iter().sum();
2290        // Sum of gradient = number of output elements (16).
2291        assert!(
2292            (grad_sum - 16.0).abs() < 1e-3,
2293            "grad sum = {grad_sum}, expected 16.0"
2294        );
2295    }
2296
2297    // -----------------------------------------------------------------------
2298    // PixelShuffle
2299    // -----------------------------------------------------------------------
2300
2301    #[test]
2302    fn test_pixel_shuffle_shape() {
2303        // [1, 4, 2, 2] with r=2 -> [1, 1, 4, 4]
2304        let data = vec![0.0f32; 16];
2305        let input = leaf_4d(&data, [1, 4, 2, 2], false);
2306        let out = pixel_shuffle(&input, 2).unwrap();
2307        assert_eq!(out.shape(), &[1, 1, 4, 4]);
2308    }
2309
2310    #[test]
2311    fn test_pixel_shuffle_values() {
2312        // [1, 4, 1, 1] with r=2 -> [1, 1, 2, 2]
2313        // Input channels: [c0_r0c0, c0_r0c1, c0_r1c0, c0_r1c1] = [1, 2, 3, 4]
2314        // Output: [[1, 2], [3, 4]]
2315        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
2316        let input = leaf_4d(&data, [1, 4, 1, 1], false);
2317        let out = pixel_shuffle(&input, 2).unwrap();
2318        assert_eq!(out.shape(), &[1, 1, 2, 2]);
2319        assert_close(out.data().unwrap(), &[1.0, 2.0, 3.0, 4.0], 1e-6);
2320    }
2321
2322    #[test]
2323    fn test_pixel_shuffle_not_divisible() {
2324        // channels=3 not divisible by r^2=4
2325        let input = leaf_4d(&[0.0; 12], [1, 3, 2, 2], false);
2326        assert!(pixel_shuffle(&input, 2).is_err());
2327    }
2328
2329    // -----------------------------------------------------------------------
2330    // PixelUnshuffle
2331    // -----------------------------------------------------------------------
2332
2333    #[test]
2334    fn test_pixel_unshuffle_shape() {
2335        // [1, 1, 4, 4] with r=2 -> [1, 4, 2, 2]
2336        let data = vec![0.0f32; 16];
2337        let input = leaf_4d(&data, [1, 1, 4, 4], false);
2338        let out = pixel_unshuffle(&input, 2).unwrap();
2339        assert_eq!(out.shape(), &[1, 4, 2, 2]);
2340    }
2341
2342    #[test]
2343    fn test_pixel_shuffle_unshuffle_roundtrip() {
2344        // Shuffle then unshuffle should give back the original.
2345        let data: Vec<f32> = (0..36).map(|i| i as f32).collect();
2346        let input = leaf_4d(&data, [1, 4, 3, 3], false);
2347        let shuffled = pixel_shuffle(&input, 2).unwrap();
2348        assert_eq!(shuffled.shape(), &[1, 1, 6, 6]);
2349        let roundtrip = pixel_unshuffle(&shuffled, 2).unwrap();
2350        assert_eq!(roundtrip.shape(), &[1, 4, 3, 3]);
2351        assert_close(roundtrip.data().unwrap(), &data, 1e-6);
2352    }
2353
2354    #[test]
2355    fn test_pixel_unshuffle_spatial_not_divisible() {
2356        // H=3 not divisible by r=2
2357        let input = leaf_4d(&[0.0; 9], [1, 1, 3, 3], false);
2358        assert!(pixel_unshuffle(&input, 2).is_err());
2359    }
2360
2361    // -----------------------------------------------------------------------
2362    // PixelShuffle backward
2363    // -----------------------------------------------------------------------
2364
2365    #[test]
2366    fn test_pixel_shuffle_backward() {
2367        let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
2368        let input = leaf_4d(&data, [1, 4, 2, 2], true);
2369        let out = pixel_shuffle(&input, 2).unwrap();
2370
2371        let out_data = out.data().unwrap().to_vec();
2372        let total: f32 = out_data.iter().sum();
2373        let loss = Tensor::from_operation(
2374            TensorStorage::cpu(vec![total]),
2375            vec![],
2376            Arc::new(TestSumBackward { input: out }),
2377        )
2378        .unwrap();
2379        loss.backward().unwrap();
2380
2381        let grad = input.grad().unwrap().unwrap();
2382        let g = grad.data().unwrap();
2383        // Gradient of sum = 1 everywhere.
2384        for (i, &val) in g.iter().enumerate() {
2385            assert!(
2386                (val - 1.0).abs() < 1e-5,
2387                "grad[{i}]: expected 1.0, got {val}"
2388            );
2389        }
2390    }
2391
2392    // -----------------------------------------------------------------------
2393    // Unfold
2394    // -----------------------------------------------------------------------
2395
2396    #[test]
2397    fn test_unfold_shape() {
2398        // [1, 1, 4, 4], kernel 2x2, stride 1, no padding, no dilation
2399        // out_h = (4 - 2) / 1 + 1 = 3, out_w = 3, L = 9, k = 1*2*2 = 4
2400        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
2401        let out = unfold(&input, [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
2402        assert_eq!(out.shape(), &[1, 4, 9]);
2403    }
2404
2405    #[test]
2406    fn test_unfold_values() {
2407        // [1, 1, 3, 3], kernel 2x2, stride 1
2408        // Input:
2409        // 1 2 3
2410        // 4 5 6
2411        // 7 8 9
2412        //
2413        // Patches (each window is 4 elements, 4 windows):
2414        // window (0,0): [1, 2, 4, 5]
2415        // window (0,1): [2, 3, 5, 6]
2416        // window (1,0): [4, 5, 7, 8]
2417        // window (1,1): [5, 6, 8, 9]
2418        //
2419        // Output [1, 4, 4]:
2420        // k=0 (c=0, kh=0, kw=0): [1, 2, 4, 5]
2421        // k=1 (c=0, kh=0, kw=1): [2, 3, 5, 6]
2422        // k=2 (c=0, kh=1, kw=0): [4, 5, 7, 8]
2423        // k=3 (c=0, kh=1, kw=1): [5, 6, 8, 9]
2424        #[rustfmt::skip]
2425        let data: Vec<f32> = vec![
2426            1.0, 2.0, 3.0,
2427            4.0, 5.0, 6.0,
2428            7.0, 8.0, 9.0,
2429        ];
2430        let input = leaf_4d(&data, [1, 1, 3, 3], false);
2431        let out = unfold(&input, [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
2432        assert_eq!(out.shape(), &[1, 4, 4]);
2433
2434        let d = out.data().unwrap();
2435        assert_close(&d[0..4], &[1.0, 2.0, 4.0, 5.0], 1e-6);
2436        assert_close(&d[4..8], &[2.0, 3.0, 5.0, 6.0], 1e-6);
2437        assert_close(&d[8..12], &[4.0, 5.0, 7.0, 8.0], 1e-6);
2438        assert_close(&d[12..16], &[5.0, 6.0, 8.0, 9.0], 1e-6);
2439    }
2440
2441    #[test]
2442    fn test_unfold_with_padding() {
2443        // [1, 1, 2, 2], kernel 2x2, stride 1, padding 1
2444        // Padded: [4, 4], out_h = (2 + 2*1 - 2)/1 + 1 = 3, L = 9, k = 4
2445        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
2446        let input = leaf_4d(&data, [1, 1, 2, 2], false);
2447        let out = unfold(&input, [2, 2], [1, 1], [1, 1], [1, 1]).unwrap();
2448        assert_eq!(out.shape(), &[1, 4, 9]);
2449    }
2450
2451    #[test]
2452    fn test_unfold_with_stride() {
2453        // [1, 1, 4, 4], kernel 2x2, stride 2
2454        // out_h = (4 - 2)/2 + 1 = 2, L = 4, k = 4
2455        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
2456        let out = unfold(&input, [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
2457        assert_eq!(out.shape(), &[1, 4, 4]);
2458    }
2459
2460    #[test]
2461    fn test_unfold_zero_kernel_rejected() {
2462        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
2463        assert!(unfold(&input, [0, 2], [1, 1], [0, 0], [1, 1]).is_err());
2464    }
2465
2466    // -----------------------------------------------------------------------
2467    // Fold
2468    // -----------------------------------------------------------------------
2469
2470    #[test]
2471    fn test_fold_shape() {
2472        // [1, 4, 9] -> [1, 1, 4, 4] with kernel 2x2, stride 1
2473        let data = vec![0.0f32; 36];
2474        let input = leaf(&data, &[1, 4, 9], false);
2475        let out = fold(&input, [4, 4], [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
2476        assert_eq!(out.shape(), &[1, 1, 4, 4]);
2477    }
2478
2479    #[test]
2480    fn test_unfold_fold_roundtrip() {
2481        // For non-overlapping patches (stride=kernel), fold(unfold(x)) == x.
2482        #[rustfmt::skip]
2483        let data: Vec<f32> = vec![
2484            1.0, 2.0, 3.0, 4.0,
2485            5.0, 6.0, 7.0, 8.0,
2486            9.0, 10.0, 11.0, 12.0,
2487            13.0, 14.0, 15.0, 16.0,
2488        ];
2489        let input = leaf_4d(&data, [1, 1, 4, 4], false);
2490        let unfolded = unfold(&input, [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
2491        let refolded = fold(&unfolded, [4, 4], [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
2492        assert_eq!(refolded.shape(), &[1, 1, 4, 4]);
2493        assert_close(refolded.data().unwrap(), &data, 1e-6);
2494    }
2495
2496    #[test]
2497    fn test_fold_l_mismatch() {
2498        // L doesn't match output_size
2499        let data = vec![0.0f32; 20];
2500        let input = leaf(&data, &[1, 4, 5], false);
2501        assert!(fold(&input, [4, 4], [2, 2], [1, 1], [0, 0], [1, 1]).is_err());
2502    }
2503
2504    // -----------------------------------------------------------------------
2505    // grid_sample
2506    // -----------------------------------------------------------------------
2507
2508    #[test]
2509    fn test_grid_sample_identity() {
2510        // Create a grid that maps each output pixel to itself (identity transform).
2511        // align_corners=true, so normalized coords are [-1, 1].
2512        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
2513        let input = leaf_4d(&data, [1, 1, 2, 2], false);
2514
2515        // Grid for 2x2 output:
2516        // (0,0) -> (-1,-1), (0,1) -> (1,-1), (1,0) -> (-1,1), (1,1) -> (1,1)
2517        let grid_data: Vec<f32> = vec![-1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0];
2518        let grid = leaf(&grid_data, &[1, 2, 2, 2], false);
2519
2520        let out = grid_sample(
2521            &input,
2522            &grid,
2523            GridSampleMode::Bilinear,
2524            GridSamplePaddingMode::Zeros,
2525            true,
2526        )
2527        .unwrap();
2528        assert_eq!(out.shape(), &[1, 1, 2, 2]);
2529        assert_close(out.data().unwrap(), &data, 1e-5);
2530    }
2531
2532    #[test]
2533    fn test_grid_sample_nearest() {
2534        let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
2535        let input = leaf_4d(&data, [1, 1, 2, 2], false);
2536
2537        // Grid that samples center of pixel (0,0) -> should get 1.0
2538        let grid_data: Vec<f32> = vec![-1.0, -1.0];
2539        let grid = leaf(&grid_data, &[1, 1, 1, 2], false);
2540
2541        let out = grid_sample(
2542            &input,
2543            &grid,
2544            GridSampleMode::Nearest,
2545            GridSamplePaddingMode::Zeros,
2546            true,
2547        )
2548        .unwrap();
2549        assert_eq!(out.shape(), &[1, 1, 1, 1]);
2550        assert!((out.data().unwrap()[0] - 1.0).abs() < 1e-5);
2551    }
2552
2553    #[test]
2554    fn test_grid_sample_batch_mismatch() {
2555        let input = leaf_4d(&[0.0; 8], [2, 1, 2, 2], false);
2556        let grid = leaf(&[0.0; 8], &[1, 2, 2, 2], false);
2557        assert!(
2558            grid_sample(
2559                &input,
2560                &grid,
2561                GridSampleMode::Bilinear,
2562                GridSamplePaddingMode::Zeros,
2563                true
2564            )
2565            .is_err()
2566        );
2567    }
2568
2569    #[test]
2570    fn test_grid_sample_wrong_grid_shape() {
2571        let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
2572        let grid = leaf(&[0.0; 8], &[1, 2, 4], false);
2573        assert!(
2574            grid_sample(
2575                &input,
2576                &grid,
2577                GridSampleMode::Bilinear,
2578                GridSamplePaddingMode::Zeros,
2579                true
2580            )
2581            .is_err()
2582        );
2583    }
2584
2585    // -----------------------------------------------------------------------
2586    // affine_grid
2587    // -----------------------------------------------------------------------
2588
2589    #[test]
2590    fn test_affine_grid_identity() {
2591        // Identity transform: [[1, 0, 0], [0, 1, 0]]
2592        let theta_data: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
2593        let theta = leaf(&theta_data, &[1, 2, 3], false);
2594        let grid = affine_grid(&theta, [1, 1, 3, 3], true).unwrap();
2595        assert_eq!(grid.shape(), &[1, 3, 3, 2]);
2596
2597        let d = grid.data().unwrap();
2598        // Corners of the grid should be at (-1,-1), (1,-1), (-1,1), (1,1)
2599        // Top-left: (iy=0, ix=0) -> x=-1, y=-1
2600        assert!((d[0] - (-1.0)).abs() < 1e-5); // x
2601        assert!((d[1] - (-1.0)).abs() < 1e-5); // y
2602        // Top-right: (iy=0, ix=2) -> x=1, y=-1
2603        assert!((d[4] - 1.0).abs() < 1e-5); // x
2604        assert!((d[5] - (-1.0)).abs() < 1e-5); // y
2605    }
2606
2607    #[test]
2608    fn test_affine_grid_theta_shape_error() {
2609        let theta = leaf(&[0.0; 12], &[2, 3, 2], false);
2610        assert!(affine_grid(&theta, [2, 1, 3, 3], true).is_err());
2611    }
2612
2613    #[test]
2614    fn test_affine_grid_batch_mismatch() {
2615        let theta = leaf(&[0.0; 6], &[1, 2, 3], false);
2616        assert!(affine_grid(&theta, [2, 1, 3, 3], true).is_err());
2617    }
2618
2619    // -----------------------------------------------------------------------
2620    // PixelShuffle / PixelUnshuffle module
2621    // -----------------------------------------------------------------------
2622
2623    #[test]
2624    fn test_pixel_shuffle_module() {
2625        let ps = PixelShuffle::new(2);
2626        let input = leaf_4d(&[0.0; 16], [1, 4, 2, 2], false);
2627        let out: Tensor<f32> = Module::<f32>::forward(&ps, &input).unwrap();
2628        assert_eq!(out.shape(), &[1, 1, 4, 4]);
2629    }
2630
2631    #[test]
2632    fn test_pixel_unshuffle_module() {
2633        let pus = PixelUnshuffle::new(2);
2634        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
2635        let out: Tensor<f32> = Module::<f32>::forward(&pus, &input).unwrap();
2636        assert_eq!(out.shape(), &[1, 4, 2, 2]);
2637    }
2638
2639    // -----------------------------------------------------------------------
2640    // Unfold / Fold modules
2641    // -----------------------------------------------------------------------
2642
2643    #[test]
2644    fn test_unfold_module() {
2645        let uf = Unfold::new([2, 2], [1, 1], [0, 0], [1, 1]);
2646        let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
2647        let out: Tensor<f32> = Module::<f32>::forward(&uf, &input).unwrap();
2648        assert_eq!(out.shape(), &[1, 4, 9]);
2649    }
2650
2651    #[test]
2652    fn test_fold_module() {
2653        let f = Fold::new([4, 4], [2, 2], [1, 1], [0, 0], [1, 1]);
2654        let data = vec![0.0f32; 36];
2655        let input = leaf(&data, &[1, 4, 9], false);
2656        let out: Tensor<f32> = Module::<f32>::forward(&f, &input).unwrap();
2657        assert_eq!(out.shape(), &[1, 1, 4, 4]);
2658    }
2659
2660    // -----------------------------------------------------------------------
2661    // Unfold backward
2662    // -----------------------------------------------------------------------
2663
2664    #[test]
2665    fn test_unfold_backward() {
2666        // Non-overlapping unfold (stride = kernel): fold(unfold(x)) = x, so
2667        // gradient through sum(unfold(x)) should be all 1s.
2668        let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
2669        let input = leaf_4d(&data, [1, 1, 4, 4], true);
2670        let out = unfold(&input, [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
2671
2672        let out_data = out.data().unwrap().to_vec();
2673        let total: f32 = out_data.iter().sum();
2674        let loss = Tensor::from_operation(
2675            TensorStorage::cpu(vec![total]),
2676            vec![],
2677            Arc::new(TestSumBackward { input: out }),
2678        )
2679        .unwrap();
2680        loss.backward().unwrap();
2681
2682        let grad = input.grad().unwrap().unwrap();
2683        let g = grad.data().unwrap();
2684        for (i, &val) in g.iter().enumerate() {
2685            assert!(
2686                (val - 1.0).abs() < 1e-5,
2687                "grad[{i}]: expected 1.0, got {val}"
2688            );
2689        }
2690    }
2691
2692    // -----------------------------------------------------------------------
2693    // Unfold backward with overlapping patches
2694    // -----------------------------------------------------------------------
2695
2696    #[test]
2697    fn test_unfold_backward_overlapping() {
2698        // Overlapping unfold (stride < kernel): each input pixel appears in
2699        // multiple patches, so gradient should be > 1 for interior pixels.
2700        let data: Vec<f32> = (0..9).map(|i| i as f32).collect();
2701        let input = leaf_4d(&data, [1, 1, 3, 3], true);
2702        let out = unfold(&input, [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
2703        // out shape: [1, 4, 4]
2704
2705        let out_data = out.data().unwrap().to_vec();
2706        let total: f32 = out_data.iter().sum();
2707        let loss = Tensor::from_operation(
2708            TensorStorage::cpu(vec![total]),
2709            vec![],
2710            Arc::new(TestSumBackward { input: out }),
2711        )
2712        .unwrap();
2713        loss.backward().unwrap();
2714
2715        let grad = input.grad().unwrap().unwrap();
2716        let g = grad.data().unwrap();
2717        // Corner pixels appear in 1 patch: grad=1
2718        // Edge pixels appear in 2 patches: grad=2
2719        // Center pixel appears in 4 patches: grad=4
2720        #[rustfmt::skip]
2721        let expected: Vec<f32> = vec![
2722            1.0, 2.0, 1.0,
2723            2.0, 4.0, 2.0,
2724            1.0, 2.0, 1.0,
2725        ];
2726        assert_close(g, &expected, 1e-5);
2727    }
2728
2729    // -----------------------------------------------------------------------
2730    // Multichannel batch tests
2731    // -----------------------------------------------------------------------
2732
2733    #[test]
2734    fn test_interpolate_multichannel_batch() {
2735        // [2, 3, 2, 2] -> [2, 3, 4, 4]
2736        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
2737        let input = leaf_4d(&data, [2, 3, 2, 2], false);
2738        let out = interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).unwrap();
2739        assert_eq!(out.shape(), &[2, 3, 4, 4]);
2740    }
2741
2742    // -----------------------------------------------------------------------
2743    // Helper backward node for tests
2744    // -----------------------------------------------------------------------
2745
2746    #[derive(Debug)]
2747    struct TestSumBackward {
2748        input: Tensor<f32>,
2749    }
2750
2751    impl GradFn<f32> for TestSumBackward {
2752        fn backward(
2753            &self,
2754            _grad_output: &Tensor<f32>,
2755        ) -> FerrotorchResult<Vec<Option<Tensor<f32>>>> {
2756            let ones_data = vec![1.0f32; self.input.numel()];
2757            let ones = Tensor::from_storage(
2758                TensorStorage::cpu(ones_data),
2759                self.input.shape().to_vec(),
2760                false,
2761            )?;
2762            Ok(vec![Some(ones)])
2763        }
2764
2765        fn inputs(&self) -> Vec<&Tensor<f32>> {
2766            vec![&self.input]
2767        }
2768
2769        fn name(&self) -> &'static str {
2770            "TestSumBackward"
2771        }
2772    }
2773}