use std::sync::Arc;
use ferrotorch_core::autograd::no_grad::is_grad_enabled;
use ferrotorch_core::tensor::GradFn;
use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, TensorStorage};
use crate::module::Module;
use crate::parameter::Parameter;
/// Resampling kernel selected by callers of [`interpolate`] and [`Upsample`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InterpolateMode {
/// Copy the single source pixel at `floor(dst * in / out)`.
/// `interpolate` rejects this mode when `align_corners` is set.
Nearest,
/// Blend the 2x2 source neighbourhood with linear weights.
Bilinear,
/// Blend the 4x4 source neighbourhood with cubic-convolution weights (`a = -0.75`).
Bicubic,
}
/// How [`grid_sample`] treats sample indices that fall outside the input image.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GridSamplePaddingMode {
/// Indices are left untouched; out-of-range pixels read as zero.
Zeros,
/// Indices are clamped to the nearest valid row/column.
Border,
/// Indices are mirrored about the first and last pixel (period `2 * (size - 1)`).
Reflection,
}
/// Interpolation kernel used by [`grid_sample`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GridSampleMode {
/// Blend the four pixels surrounding the sampling location with linear weights.
Bilinear,
/// Read the single pixel obtained by rounding the sampling location.
Nearest,
}
/// Confirms that `input` has exactly four dimensions and returns them as a tuple.
///
/// `fn_name` is the calling operation's name; it only appears as the prefix of
/// the error message.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] when the rank is not 4.
fn validate_4d<T: Float>(
    input: &Tensor<T>,
    fn_name: &str,
) -> FerrotorchResult<(usize, usize, usize, usize)> {
    let dims = input.shape();
    if dims.len() == 4 {
        return Ok((dims[0], dims[1], dims[2], dims[3]));
    }
    Err(FerrotorchError::InvalidArgument {
        message: format!(
            "{fn_name} expects 4D input [B, C, H, W], got shape {:?}",
            dims
        ),
    })
}
/// Cubic-convolution weight for a tap at signed distance `t` from the sample
/// point, using the coefficient `a = -0.75`.
///
/// The kernel is symmetric in `t`, equals 1 at `t = 0`, and is zero for
/// `|t| >= 2` (and for NaN, which fails both range tests).
#[inline]
fn cubic_weight(t: f64) -> f64 {
    const A: f64 = -0.75;
    let x = t.abs();
    match x {
        // Inner lobe: |t| <= 1.
        x if x <= 1.0 => (A + 2.0) * x * x * x - (A + 3.0) * x * x + 1.0,
        // Outer lobe: 1 < |t| < 2.
        x if x < 2.0 => A * x * x * x - 5.0 * A * x * x + 8.0 * A * x - 4.0 * A,
        _ => 0.0,
    }
}
/// Maps destination index `i` to a fractional source coordinate such that the
/// first and last samples of both grids coincide.
///
/// Returns `0.0` when the destination has at most one sample, which avoids
/// dividing by `out_size - 1 == 0`. Callers must pass `in_size >= 1` whenever
/// `out_size > 1`, because `in_size - 1` is computed in `usize`.
#[inline]
fn align_corners_coord(i: usize, in_size: usize, out_size: usize) -> f64 {
    match out_size {
        0 | 1 => 0.0,
        _ => {
            let in_span = (in_size - 1) as f64;
            let out_span = (out_size - 1) as f64;
            (i as f64) * in_span / out_span
        }
    }
}
/// Maps destination index `i` to a fractional source coordinate by aligning
/// pixel centres: `(i + 0.5) * (in_size / out_size) - 0.5`.
///
/// The result can be negative or exceed `in_size - 1`; callers clamp the
/// derived integer indices. `out_size` must be non-zero for a finite result.
#[inline]
fn half_pixel_coord(i: usize, in_size: usize, out_size: usize) -> f64 {
    let scale = in_size as f64 / out_size as f64;
    let dst_center = i as f64 + 0.5;
    dst_center * scale - 0.5
}
/// Clamps a signed index into the inclusive range `0..=max`.
#[inline]
fn clamp_coord(val: isize, max: usize) -> usize {
    // Lift negatives to zero first so the cast to `usize` is lossless.
    (val.max(0) as usize).min(max)
}
/// Resamples a 4D `[B, C, H, W]` tensor to a new spatial size.
///
/// Exactly one of `size` (`[h_out, w_out]`) or `scale_factor` (`[sh, sw]`)
/// must be given. With `scale_factor` the output extent is
/// `round(input_extent * factor)`.
///
/// NOTE(review): PyTorch's `interpolate` floors this product instead of
/// rounding — confirm whether the difference is intentional.
///
/// The computation runs on a host copy of the data (`data_vec`), and the
/// result is moved back to the input's device. When gradients are enabled and
/// `input` requires grad, the result is linked to an `InterpolateBackward`
/// node.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] when
/// * `input` is not 4D,
/// * both or neither of `size` / `scale_factor` are supplied,
/// * a scale factor is NaN or infinite,
/// * the requested output has a zero extent,
/// * a non-empty input has a zero spatial extent (nothing to sample from), or
/// * `align_corners` is combined with [`InterpolateMode::Nearest`].
pub fn interpolate<T: Float>(
    input: &Tensor<T>,
    size: Option<[usize; 2]>,
    scale_factor: Option<[f64; 2]>,
    mode: InterpolateMode,
    align_corners: bool,
) -> FerrotorchResult<Tensor<T>> {
    let (batch, channels, h_in, w_in) = validate_4d(input, "interpolate")?;

    let (h_out, w_out) = match (size, scale_factor) {
        (Some([h, w]), None) => (h, w),
        (None, Some(sf)) => {
            // A non-finite factor would saturate the float-to-usize cast and
            // overflow the allocation size below.
            if sf.iter().any(|f| !f.is_finite()) {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!("interpolate: scale_factor {sf:?} must be finite"),
                });
            }
            let h = (h_in as f64 * sf[0]).round() as usize;
            let w = (w_in as f64 * sf[1]).round() as usize;
            if h == 0 || w == 0 {
                return Err(FerrotorchError::InvalidArgument {
                    message: format!(
                        "interpolate: scale_factor {sf:?} with input ({h_in}, {w_in}) produces zero output"
                    ),
                });
            }
            (h, w)
        }
        _ => {
            return Err(FerrotorchError::InvalidArgument {
                message: "interpolate: exactly one of size or scale_factor must be provided".into(),
            });
        }
    };

    if h_out == 0 || w_out == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("interpolate: output size ({h_out}, {w_out}) must be > 0"),
        });
    }
    // The kernels compute `h_in - 1` / `w_in - 1`; with an empty spatial extent
    // that underflows. Only reject when the kernels would actually run, so an
    // empty batch keeps producing an empty result.
    if (h_in == 0 || w_in == 0) && batch != 0 && channels != 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("interpolate: input spatial size ({h_in}, {w_in}) must be > 0"),
        });
    }
    if mode == InterpolateMode::Nearest && align_corners {
        return Err(FerrotorchError::InvalidArgument {
            message: "interpolate: align_corners is not supported with nearest mode".into(),
        });
    }

    let input_device = input.device();
    let data = input.data_vec()?;
    let total = batch * channels * h_out * w_out;
    let mut output = vec![T::from(0.0).unwrap(); total];

    match mode {
        InterpolateMode::Nearest => {
            nearest_forward(&data, &mut output, batch, channels, h_in, w_in, h_out, w_out);
        }
        InterpolateMode::Bilinear => {
            bilinear_forward(
                &data,
                &mut output,
                batch,
                channels,
                h_in,
                w_in,
                h_out,
                w_out,
                align_corners,
            );
        }
        InterpolateMode::Bicubic => {
            bicubic_forward(
                &data,
                &mut output,
                batch,
                channels,
                h_in,
                w_in,
                h_out,
                w_out,
                align_corners,
            );
        }
    }

    let out_shape = vec![batch, channels, h_out, w_out];
    let storage = TensorStorage::cpu(output);
    if is_grad_enabled() && input.requires_grad() {
        Tensor::from_operation(
            storage,
            out_shape,
            Arc::new(InterpolateBackward {
                input: input.clone(),
                h_out,
                w_out,
                mode,
                align_corners,
            }),
        )?
        .to(input_device)
    } else {
        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
    }
}
/// Nearest-neighbour resampling of `batch * channels` planes.
///
/// `data` is read as a contiguous `[batch, channels, h_in, w_in]` buffer and
/// `output` is written as `[batch, channels, h_out, w_out]`. Each destination
/// pixel copies the source pixel at `floor(dst * in / out)`, capped at the
/// last row/column.
#[allow(clippy::too_many_arguments)]
fn nearest_forward<T: Float>(
    data: &[T],
    output: &mut [T],
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    h_out: usize,
    w_out: usize,
) {
    let row_ratio = h_in as f64 / h_out as f64;
    let col_ratio = w_in as f64 / w_out as f64;
    // Batch and channel only select a plane, so iterate planes directly.
    for plane in 0..batch * channels {
        let src_plane = plane * h_in;
        let dst_plane = plane * h_out;
        for oh in 0..h_out {
            let ih = ((oh as f64 * row_ratio).floor() as usize).min(h_in - 1);
            let src_row = (src_plane + ih) * w_in;
            let dst_row = (dst_plane + oh) * w_out;
            for ow in 0..w_out {
                let iw = ((ow as f64 * col_ratio).floor() as usize).min(w_in - 1);
                output[dst_row + ow] = data[src_row + iw];
            }
        }
    }
}
/// Bilinear resampling of `batch * channels` planes.
///
/// `data` is read as `[batch, channels, h_in, w_in]`, `output` is written as
/// `[batch, channels, h_out, w_out]`. The source coordinate of every
/// destination pixel comes from `align_corners_coord` or `half_pixel_coord`
/// depending on `align_corners`; the four neighbouring indices are clamped
/// to the image, so edge pixels are replicated.
#[allow(clippy::too_many_arguments)]
fn bilinear_forward<T: Float>(
    data: &[T],
    output: &mut [T],
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    h_out: usize,
    w_out: usize,
    align_corners: bool,
) {
    let one = T::from(1.0).unwrap();
    let src_coord = |dst: usize, n_in: usize, n_out: usize| -> f64 {
        if align_corners {
            align_corners_coord(dst, n_in, n_out)
        } else {
            half_pixel_coord(dst, n_in, n_out)
        }
    };

    for plane in 0..batch * channels {
        let src_plane = plane * h_in;
        let dst_plane = plane * h_out;
        for oh in 0..h_out {
            let y = src_coord(oh, h_in, h_out);
            let y_lo = y.floor() as isize;
            let y_hi = y_lo + 1;
            // Fractional distance from the upper row: the weight of the lower row.
            let wy = T::from(y - y_lo as f64).unwrap();
            for ow in 0..w_out {
                let x = src_coord(ow, w_in, w_out);
                let x_lo = x.floor() as isize;
                let x_hi = x_lo + 1;
                let wx = T::from(x - x_lo as f64).unwrap();

                let top = (src_plane + clamp_coord(y_lo, h_in - 1)) * w_in;
                let bottom = (src_plane + clamp_coord(y_hi, h_in - 1)) * w_in;
                let left = clamp_coord(x_lo, w_in - 1);
                let right = clamp_coord(x_hi, w_in - 1);

                let top_left = data[top + left];
                let top_right = data[top + right];
                let bottom_left = data[bottom + left];
                let bottom_right = data[bottom + right];

                output[(dst_plane + oh) * w_out + ow] = top_left * (one - wy) * (one - wx)
                    + top_right * (one - wy) * wx
                    + bottom_left * wy * (one - wx)
                    + bottom_right * wy * wx;
            }
        }
    }
}
/// Bicubic resampling of `batch * channels` planes.
///
/// Every destination pixel is a separable weighted sum over a 4x4 source
/// window starting one row/column before the floored source coordinate.
/// Weights come from `cubic_weight`; window indices are clamped to the image
/// so edge pixels are replicated.
#[allow(clippy::too_many_arguments)]
fn bicubic_forward<T: Float>(
    data: &[T],
    output: &mut [T],
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    h_out: usize,
    w_out: usize,
    align_corners: bool,
) {
    let src_coord = |dst: usize, n_in: usize, n_out: usize| -> f64 {
        if align_corners {
            align_corners_coord(dst, n_in, n_out)
        } else {
            half_pixel_coord(dst, n_in, n_out)
        }
    };
    // Weights of the four taps at offsets -1, 0, +1, +2 from the floored coordinate.
    let taps = |frac: f64| -> [T; 4] {
        [
            T::from(cubic_weight(frac + 1.0)).unwrap(),
            T::from(cubic_weight(frac)).unwrap(),
            T::from(cubic_weight(frac - 1.0)).unwrap(),
            T::from(cubic_weight(frac - 2.0)).unwrap(),
        ]
    };

    for plane in 0..batch * channels {
        let src_plane = plane * h_in;
        let dst_plane = plane * h_out;
        for oh in 0..h_out {
            let y = src_coord(oh, h_in, h_out);
            let y_floor = y.floor() as isize;
            let row_weights = taps(y - y_floor as f64);
            for ow in 0..w_out {
                let x = src_coord(ow, w_in, w_out);
                let x_floor = x.floor() as isize;
                let col_weights = taps(x - x_floor as f64);

                let mut acc = T::from(0.0).unwrap();
                for (dy, &wy) in row_weights.iter().enumerate() {
                    let iy = clamp_coord(y_floor - 1 + dy as isize, h_in - 1);
                    for (dx, &wx) in col_weights.iter().enumerate() {
                        let ix = clamp_coord(x_floor - 1 + dx as isize, w_in - 1);
                        let pixel = data[(src_plane + iy) * w_in + ix];
                        acc += pixel * wy * wx;
                    }
                }
                output[(dst_plane + oh) * w_out + ow] = acc;
            }
        }
    }
}
#[derive(Debug)]
struct InterpolateBackward<T: Float> {
input: Tensor<T>,
h_out: usize,
w_out: usize,
mode: InterpolateMode,
align_corners: bool,
}
impl<T: Float> GradFn<T> for InterpolateBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if !self.input.requires_grad() {
return Ok(vec![None]);
}
let in_shape = self.input.shape();
let (batch, channels, h_in, w_in) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
let h_out = self.h_out;
let w_out = self.w_out;
let go_data = grad_output.data_vec()?;
let mut grad_input = vec![T::from(0.0).unwrap(); batch * channels * h_in * w_in];
match self.mode {
InterpolateMode::Nearest => {
nearest_backward(
&go_data,
&mut grad_input,
batch,
channels,
h_in,
w_in,
h_out,
w_out,
);
}
InterpolateMode::Bilinear => {
bilinear_backward(
&go_data,
&mut grad_input,
batch,
channels,
h_in,
w_in,
h_out,
w_out,
self.align_corners,
);
}
InterpolateMode::Bicubic => {
bicubic_backward(
&go_data,
&mut grad_input,
batch,
channels,
h_in,
w_in,
h_out,
w_out,
self.align_corners,
);
}
}
let grad_tensor = Tensor::from_storage(
TensorStorage::cpu(grad_input),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(grad_tensor)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"InterpolateBackward"
}
}
/// Gradient of `nearest_forward`.
///
/// Each element of `go` (`[batch, channels, h_out, w_out]`) is added to the
/// element of `grad_input` (`[batch, channels, h_in, w_in]`) it was copied
/// from. `grad_input` is accumulated into, so it must arrive zero-initialised.
#[allow(clippy::too_many_arguments)]
fn nearest_backward<T: Float>(
    go: &[T],
    grad_input: &mut [T],
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    h_out: usize,
    w_out: usize,
) {
    let row_ratio = h_in as f64 / h_out as f64;
    let col_ratio = w_in as f64 / w_out as f64;
    for plane in 0..batch * channels {
        let src_plane = plane * h_in;
        let dst_plane = plane * h_out;
        for oh in 0..h_out {
            let ih = ((oh as f64 * row_ratio).floor() as usize).min(h_in - 1);
            let src_row = (src_plane + ih) * w_in;
            let dst_row = (dst_plane + oh) * w_out;
            for ow in 0..w_out {
                let iw = ((ow as f64 * col_ratio).floor() as usize).min(w_in - 1);
                grad_input[src_row + iw] += go[dst_row + ow];
            }
        }
    }
}
/// Gradient of `bilinear_forward`.
///
/// For every element of `go` (`[batch, channels, h_out, w_out]`) the four
/// source pixels that contributed to it receive the upstream gradient scaled
/// by their interpolation weight. `grad_input` is accumulated into and must
/// arrive zero-initialised.
#[allow(clippy::too_many_arguments)]
fn bilinear_backward<T: Float>(
    go: &[T],
    grad_input: &mut [T],
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    h_out: usize,
    w_out: usize,
    align_corners: bool,
) {
    let one = T::from(1.0).unwrap();
    let src_coord = |dst: usize, n_in: usize, n_out: usize| -> f64 {
        if align_corners {
            align_corners_coord(dst, n_in, n_out)
        } else {
            half_pixel_coord(dst, n_in, n_out)
        }
    };

    for plane in 0..batch * channels {
        let src_plane = plane * h_in;
        let dst_plane = plane * h_out;
        for oh in 0..h_out {
            let y = src_coord(oh, h_in, h_out);
            let y_lo = y.floor() as isize;
            let y_hi = y_lo + 1;
            let wy = T::from(y - y_lo as f64).unwrap();
            for ow in 0..w_out {
                let x = src_coord(ow, w_in, w_out);
                let x_lo = x.floor() as isize;
                let x_hi = x_lo + 1;
                let wx = T::from(x - x_lo as f64).unwrap();

                let upstream = go[(dst_plane + oh) * w_out + ow];

                let top = (src_plane + clamp_coord(y_lo, h_in - 1)) * w_in;
                let bottom = (src_plane + clamp_coord(y_hi, h_in - 1)) * w_in;
                let left = clamp_coord(x_lo, w_in - 1);
                let right = clamp_coord(x_hi, w_in - 1);

                // Clamped corners may coincide at the border; the additions
                // then accumulate on the same cell.
                grad_input[top + left] += upstream * (one - wy) * (one - wx);
                grad_input[top + right] += upstream * (one - wy) * wx;
                grad_input[bottom + left] += upstream * wy * (one - wx);
                grad_input[bottom + right] += upstream * wy * wx;
            }
        }
    }
}
/// Gradient of `bicubic_forward`.
///
/// Each element of `go` is distributed over the same clamped 4x4 source
/// window and with the same separable cubic weights that the forward pass
/// used. `grad_input` is accumulated into and must arrive zero-initialised.
#[allow(clippy::too_many_arguments)]
fn bicubic_backward<T: Float>(
    go: &[T],
    grad_input: &mut [T],
    batch: usize,
    channels: usize,
    h_in: usize,
    w_in: usize,
    h_out: usize,
    w_out: usize,
    align_corners: bool,
) {
    let src_coord = |dst: usize, n_in: usize, n_out: usize| -> f64 {
        if align_corners {
            align_corners_coord(dst, n_in, n_out)
        } else {
            half_pixel_coord(dst, n_in, n_out)
        }
    };
    // Weights of the four taps at offsets -1, 0, +1, +2 from the floored coordinate.
    let taps = |frac: f64| -> [T; 4] {
        [
            T::from(cubic_weight(frac + 1.0)).unwrap(),
            T::from(cubic_weight(frac)).unwrap(),
            T::from(cubic_weight(frac - 1.0)).unwrap(),
            T::from(cubic_weight(frac - 2.0)).unwrap(),
        ]
    };

    for plane in 0..batch * channels {
        let src_plane = plane * h_in;
        let dst_plane = plane * h_out;
        for oh in 0..h_out {
            let y = src_coord(oh, h_in, h_out);
            let y_floor = y.floor() as isize;
            let row_weights = taps(y - y_floor as f64);
            for ow in 0..w_out {
                let x = src_coord(ow, w_in, w_out);
                let x_floor = x.floor() as isize;
                let col_weights = taps(x - x_floor as f64);

                let upstream = go[(dst_plane + oh) * w_out + ow];
                for (dy, &wy) in row_weights.iter().enumerate() {
                    let iy = clamp_coord(y_floor - 1 + dy as isize, h_in - 1);
                    for (dx, &wx) in col_weights.iter().enumerate() {
                        let ix = clamp_coord(x_floor - 1 + dx as isize, w_in - 1);
                        grad_input[(src_plane + iy) * w_in + ix] += upstream * wy * wx;
                    }
                }
            }
        }
    }
}
/// Parameter-free module that forwards to `interpolate` with a fixed
/// configuration.
#[derive(Debug, Clone)]
pub struct Upsample {
    /// Explicit output size `[h_out, w_out]`; mutually exclusive with `scale_factor`.
    pub size: Option<[usize; 2]>,
    /// Per-axis scale `[scale_h, scale_w]`; mutually exclusive with `size`.
    pub scale_factor: Option<[f64; 2]>,
    /// Resampling kernel.
    pub mode: InterpolateMode,
    /// Whether corner samples of input and output are aligned.
    pub align_corners: bool,
}

impl Upsample {
    /// Creates a module that resamples to a fixed output `size`, with
    /// `align_corners` disabled.
    pub fn new(size: [usize; 2], mode: InterpolateMode) -> Self {
        Upsample {
            size: Some(size),
            scale_factor: None,
            mode,
            align_corners: false,
        }
    }

    /// Creates a module that resamples by `scale_factor`, with
    /// `align_corners` disabled.
    pub fn with_scale_factor(scale_factor: [f64; 2], mode: InterpolateMode) -> Self {
        Upsample {
            size: None,
            scale_factor: Some(scale_factor),
            mode,
            align_corners: false,
        }
    }

    /// Builder-style setter for `align_corners`; consumes and returns the module.
    pub fn align_corners(self, align_corners: bool) -> Self {
        Upsample {
            align_corners,
            ..self
        }
    }
}
impl<T: Float> Module<T> for Upsample {
    /// Resamples `input` with the stored configuration via `interpolate`.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let Self {
            size,
            scale_factor,
            mode,
            align_corners,
        } = *self;
        interpolate(input, size, scale_factor, mode, align_corners)
    }

    /// The module has no learnable parameters.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        Vec::new()
    }

    /// The module has no learnable parameters.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        Vec::new()
    }

    /// The module has no learnable parameters.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        Vec::new()
    }

    /// No-op: the module keeps no training state.
    fn train(&mut self) {}

    /// No-op: the module keeps no training state.
    fn eval(&mut self) {}

    /// Always reports `false`.
    fn is_training(&self) -> bool {
        false
    }
}
/// Samples `input` (`[B, C, H_in, W_in]`) at the locations given by `grid`
/// (`[B, H_out, W_out, 2]`) and returns a `[B, C, H_out, W_out]` tensor.
///
/// The last grid axis holds `(x, y)` in normalised coordinates. They are
/// converted to pixel coordinates as `(g + 1) * (size - 1) / 2` when
/// `align_corners` is set and as `((g + 1) * size - 1) / 2` otherwise.
/// Padding is applied to the integer pixel indices through
/// `apply_padding_mode`; indices that are still out of range read as zero.
///
/// The computation runs on host copies of both tensors and the result is
/// moved to the input's device. When gradients are enabled and either tensor
/// requires grad, the result is linked to a `GridSampleBackward` node.
///
/// # Errors
///
/// * [`FerrotorchError::InvalidArgument`] if `input` is not 4D or `grid` is
///   not `[B, H_out, W_out, 2]`.
/// * [`FerrotorchError::ShapeMismatch`] if the batch sizes differ.
pub fn grid_sample<T: Float>(
    input: &Tensor<T>,
    grid: &Tensor<T>,
    mode: GridSampleMode,
    padding_mode: GridSamplePaddingMode,
    align_corners: bool,
) -> FerrotorchResult<Tensor<T>> {
    let (batch, channels, h_in, w_in) = validate_4d(input, "grid_sample")?;
    let grid_shape = grid.shape();
    if grid_shape.len() != 4 || grid_shape[3] != 2 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "grid_sample: grid must be [B, H_out, W_out, 2], got {:?}",
                grid_shape
            ),
        });
    }
    if grid_shape[0] != batch {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "grid_sample: batch mismatch between input ({batch}) and grid ({})",
                grid_shape[0]
            ),
        });
    }
    let (h_out, w_out) = (grid_shape[1], grid_shape[2]);

    let input_device = input.device();
    let in_data = input.data_vec()?;
    let grid_data = grid.data_vec()?;

    let zero = T::from(0.0).unwrap();
    let one = T::from(1.0).unwrap();
    let two = T::from(2.0).unwrap();
    let mut output = vec![zero; batch * channels * h_out * w_out];

    // Extent multiplier of the normalised -> pixel conversion. It does not
    // depend on the sample, so compute it once. `saturating_sub` keeps an
    // empty spatial axis from underflowing; every lookup on such an axis is
    // out of range and therefore reads as zero.
    let (x_extent, y_extent) = if align_corners {
        (
            T::from(w_in.saturating_sub(1)).unwrap(),
            T::from(h_in.saturating_sub(1)).unwrap(),
        )
    } else {
        (T::from(w_in).unwrap(), T::from(h_in).unwrap())
    };
    let unnormalize = |g: T, extent: T| -> T {
        if align_corners {
            (g + one) * extent / two
        } else {
            ((g + one) * extent - one) / two
        }
    };
    // Reads one pixel of the plane whose first row is `plane_base`, applying
    // the padding mode; anything still out of range is zero.
    let fetch = |plane_base: usize, iy: isize, ix: isize| -> T {
        let (ix, iy) = apply_padding_mode(ix, iy, w_in, h_in, padding_mode);
        if ix >= 0 && ix < w_in as isize && iy >= 0 && iy < h_in as isize {
            in_data[(plane_base + iy as usize) * w_in + ix as usize]
        } else {
            zero
        }
    };

    for b in 0..batch {
        for oh in 0..h_out {
            for ow in 0..w_out {
                let grid_base = ((b * h_out + oh) * w_out + ow) * 2;
                let sx = unnormalize(grid_data[grid_base], x_extent)
                    .to_f64()
                    .unwrap();
                let sy = unnormalize(grid_data[grid_base + 1], y_extent)
                    .to_f64()
                    .unwrap();

                // The sampling location is shared by all channels; only the
                // plane being read changes.
                match mode {
                    GridSampleMode::Nearest => {
                        let ix = sx.round() as isize;
                        let iy = sy.round() as isize;
                        for c in 0..channels {
                            let plane = b * channels + c;
                            output[(plane * h_out + oh) * w_out + ow] =
                                fetch(plane * h_in, iy, ix);
                        }
                    }
                    GridSampleMode::Bilinear => {
                        let x0 = sx.floor() as isize;
                        let y0 = sy.floor() as isize;
                        let x1 = x0 + 1;
                        let y1 = y0 + 1;
                        let tx = T::from(sx - x0 as f64).unwrap();
                        let ty = T::from(sy - y0 as f64).unwrap();
                        for c in 0..channels {
                            let plane = b * channels + c;
                            let in_base = plane * h_in;
                            let v00 = fetch(in_base, y0, x0);
                            let v01 = fetch(in_base, y0, x1);
                            let v10 = fetch(in_base, y1, x0);
                            let v11 = fetch(in_base, y1, x1);
                            output[(plane * h_out + oh) * w_out + ow] = v00
                                * (one - ty)
                                * (one - tx)
                                + v01 * (one - ty) * tx
                                + v10 * ty * (one - tx)
                                + v11 * ty * tx;
                        }
                    }
                }
            }
        }
    }

    let out_shape = vec![batch, channels, h_out, w_out];
    let storage = TensorStorage::cpu(output);
    if is_grad_enabled() && (input.requires_grad() || grid.requires_grad()) {
        Tensor::from_operation(
            storage,
            out_shape,
            Arc::new(GridSampleBackward {
                input: input.clone(),
                grid: grid.clone(),
                mode,
                padding_mode,
                align_corners,
            }),
        )?
        .to(input_device)
    } else {
        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
    }
}
/// Remaps an integer pixel index `(ix, iy)` according to `padding_mode` for
/// an image of width `w` and height `h`, returning `(ix, iy)`.
///
/// * `Zeros` returns the index unchanged (callers treat out-of-range as zero).
/// * `Border` clamps each axis to `0..=size - 1`; for `size == 0` this yields
///   `-1`, which callers still treat as out of range.
/// * `Reflection` mirrors each axis about pixel `0` and pixel `size - 1`, and
///   yields `0` when `size <= 1`.
fn apply_padding_mode(
    ix: isize,
    iy: isize,
    w: usize,
    h: usize,
    padding_mode: GridSamplePaddingMode,
) -> (isize, isize) {
    // `max` then `min` (not `clamp`) so a zero-sized axis yields -1 instead of panicking.
    fn clamp_to_border(v: isize, size: usize) -> isize {
        v.max(0).min(size as isize - 1)
    }

    fn reflect(v: isize, size: usize) -> isize {
        if size <= 1 {
            return 0;
        }
        let last = size as isize - 1;
        let period = 2 * last;
        // Mirror negatives about zero, then fold into one period.
        let folded = (if v < 0 { -v } else { v }) % period;
        if folded > last {
            period - folded
        } else {
            folded
        }
    }

    match padding_mode {
        GridSamplePaddingMode::Zeros => (ix, iy),
        GridSamplePaddingMode::Border => (clamp_to_border(ix, w), clamp_to_border(iy, h)),
        GridSamplePaddingMode::Reflection => (reflect(ix, w), reflect(iy, h)),
    }
}
/// Autograd node recorded by `grid_sample`.
///
/// Keeps both forward operands and the sampling configuration so the backward
/// pass can recompute sampling positions and weights.
#[derive(Debug)]
struct GridSampleBackward<T: Float> {
    /// Forward input image, `[B, C, H_in, W_in]`.
    input: Tensor<T>,
    /// Forward sampling grid, `[B, H_out, W_out, 2]`.
    grid: Tensor<T>,
    /// Interpolation kernel used in the forward pass.
    mode: GridSampleMode,
    /// Out-of-range handling used in the forward pass.
    padding_mode: GridSamplePaddingMode,
    /// Coordinate convention used in the forward pass.
    align_corners: bool,
}

impl<T: Float> GradFn<T> for GridSampleBackward<T> {
    /// Computes the gradients with respect to `input` and `grid`, in that order.
    ///
    /// Each entry is `None` when the corresponding tensor does not require
    /// grad. In `Nearest` mode the grid gradient is left at zero. Gradients
    /// are computed on the host and then moved to the device of the tensor
    /// they belong to, mirroring how the forward pass places its result.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let in_shape = self.input.shape();
        let (batch, channels, h_in, w_in) = (in_shape[0], in_shape[1], in_shape[2], in_shape[3]);
        let grid_shape = self.grid.shape();
        let (h_out, w_out) = (grid_shape[1], grid_shape[2]);

        let go_data = grad_output.data_vec()?;
        let in_data = self.input.data_vec()?;
        let grid_data = self.grid.data_vec()?;

        let zero = T::from(0.0).unwrap();
        let one = T::from(1.0).unwrap();
        let two = T::from(2.0).unwrap();

        let grad_input_needed = self.input.requires_grad();
        let grad_grid_needed = self.grid.requires_grad();
        let mut grad_input = if grad_input_needed {
            vec![zero; batch * channels * h_in * w_in]
        } else {
            Vec::new()
        };
        let mut grad_grid = if grad_grid_needed {
            vec![zero; batch * h_out * w_out * 2]
        } else {
            Vec::new()
        };

        let align_corners = self.align_corners;
        let padding_mode = self.padding_mode;

        // Extent multiplier of the normalised -> pixel conversion and its
        // derivative d(pixel)/d(normalised); both are loop invariant.
        let (x_extent, y_extent) = if align_corners {
            (
                T::from(w_in.saturating_sub(1)).unwrap(),
                T::from(h_in.saturating_sub(1)).unwrap(),
            )
        } else {
            (T::from(w_in).unwrap(), T::from(h_in).unwrap())
        };
        let dsx_dgx = x_extent / two;
        let dsy_dgy = y_extent / two;
        let unnormalize = |g: T, extent: T| -> T {
            if align_corners {
                (g + one) * extent / two
            } else {
                ((g + one) * extent - one) / two
            }
        };
        // Applies the padding mode and returns the in-range `(row, col)`, or
        // `None` when the location contributes nothing.
        let resolve = |iy: isize, ix: isize| -> Option<(usize, usize)> {
            let (ix, iy) = apply_padding_mode(ix, iy, w_in, h_in, padding_mode);
            if ix >= 0 && ix < w_in as isize && iy >= 0 && iy < h_in as isize {
                Some((iy as usize, ix as usize))
            } else {
                None
            }
        };

        for b in 0..batch {
            for oh in 0..h_out {
                for ow in 0..w_out {
                    let grid_base = ((b * h_out + oh) * w_out + ow) * 2;
                    let sx = unnormalize(grid_data[grid_base], x_extent)
                        .to_f64()
                        .unwrap();
                    let sy = unnormalize(grid_data[grid_base + 1], y_extent)
                        .to_f64()
                        .unwrap();

                    match self.mode {
                        GridSampleMode::Bilinear => {
                            let x0 = sx.floor() as isize;
                            let y0 = sy.floor() as isize;
                            let x1 = x0 + 1;
                            let y1 = y0 + 1;
                            let tx = T::from(sx - x0 as f64).unwrap();
                            let ty = T::from(sy - y0 as f64).unwrap();

                            // Resolved corner positions and their bilinear
                            // weights, in the order 00, 01, 10, 11.
                            let corners = [
                                (resolve(y0, x0), (one - ty) * (one - tx)),
                                (resolve(y0, x1), (one - ty) * tx),
                                (resolve(y1, x0), ty * (one - tx)),
                                (resolve(y1, x1), ty * tx),
                            ];

                            for c in 0..channels {
                                let plane = b * channels + c;
                                let g = go_data[(plane * h_out + oh) * w_out + ow];
                                let in_base = plane * h_in;

                                if grad_input_needed {
                                    for &(pos, weight) in &corners {
                                        if let Some((iy, ix)) = pos {
                                            grad_input[(in_base + iy) * w_in + ix] += g * weight;
                                        }
                                    }
                                }

                                if grad_grid_needed {
                                    let pixel = |pos: Option<(usize, usize)>| -> T {
                                        match pos {
                                            Some((iy, ix)) => in_data[(in_base + iy) * w_in + ix],
                                            None => zero,
                                        }
                                    };
                                    let v00 = pixel(corners[0].0);
                                    let v01 = pixel(corners[1].0);
                                    let v10 = pixel(corners[2].0);
                                    let v11 = pixel(corners[3].0);
                                    // Partial derivatives of the bilinear blend
                                    // with respect to the pixel coordinates.
                                    let dout_dsx = (one - ty) * (v01 - v00) + ty * (v11 - v10);
                                    let dout_dsy = (one - tx) * (v10 - v00) + tx * (v11 - v01);
                                    grad_grid[grid_base] += g * dout_dsx * dsx_dgx;
                                    grad_grid[grid_base + 1] += g * dout_dsy * dsy_dgy;
                                }
                            }
                        }
                        GridSampleMode::Nearest => {
                            if grad_input_needed {
                                let ix = sx.round() as isize;
                                let iy = sy.round() as isize;
                                if let Some((iy, ix)) = resolve(iy, ix) {
                                    for c in 0..channels {
                                        let plane = b * channels + c;
                                        grad_input[(plane * h_in + iy) * w_in + ix] +=
                                            go_data[(plane * h_out + oh) * w_out + ow];
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        // Forward moves its result to the input's device; place each gradient
        // on the device of the tensor it belongs to.
        let gi = if grad_input_needed {
            Some(
                Tensor::from_storage(TensorStorage::cpu(grad_input), in_shape.to_vec(), false)?
                    .to(self.input.device())?,
            )
        } else {
            None
        };
        let gg = if grad_grid_needed {
            Some(
                Tensor::from_storage(TensorStorage::cpu(grad_grid), grid_shape.to_vec(), false)?
                    .to(self.grid.device())?,
            )
        } else {
            None
        };
        Ok(vec![gi, gg])
    }

    /// The two differentiable operands: input image, then grid.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input, &self.grid]
    }

    /// Node name used in autograd diagnostics.
    fn name(&self) -> &'static str {
        "GridSampleBackward"
    }
}
/// Builds a `[B, H, W, 2]` sampling grid from a batch of 2x3 affine matrices.
///
/// `theta` must be `[B, 2, 3]`; `size` is the target `[B, C, H, W]` shape, of
/// which the code reads `size[0]`, `size[2]` and `size[3]`. For each output
/// pixel the normalised base coordinate `(x, y)` is transformed as
/// `theta * [x, y, 1]^T` and stored as `(x', y')` on the last axis.
///
/// With `align_corners` the base coordinates span `[-1, 1]` inclusive (an
/// axis of length <= 1 maps to 0); otherwise they are pixel centres,
/// `(2 * i + 1) / n - 1`.
///
/// The result is moved to `theta`'s device and, when gradients are enabled
/// and `theta` requires grad, linked to an `AffineGridBackward` node.
///
/// # Errors
///
/// * [`FerrotorchError::InvalidArgument`] if `theta` is not `[B, 2, 3]`.
/// * [`FerrotorchError::ShapeMismatch`] if `size[0]` differs from `theta`'s batch.
pub fn affine_grid<T: Float>(
    theta: &Tensor<T>,
    size: [usize; 4],
    align_corners: bool,
) -> FerrotorchResult<Tensor<T>> {
    let theta_shape = theta.shape();
    let is_batched_2x3 = theta_shape.len() == 3 && theta_shape[1] == 2 && theta_shape[2] == 3;
    if !is_batched_2x3 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "affine_grid: theta must be [B, 2, 3], got {:?}",
                theta_shape
            ),
        });
    }
    let batch = theta_shape[0];
    if size[0] != batch {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "affine_grid: batch mismatch: theta batch {batch}, size batch {}",
                size[0]
            ),
        });
    }
    let [_, _, h, w] = size;

    let zero = T::from(0.0).unwrap();
    let one = T::from(1.0).unwrap();
    let two = T::from(2.0).unwrap();
    let theta_data = theta.data_vec()?;
    let theta_device = theta.device();

    // Normalised coordinate of sample `idx` on an axis with `extent` samples.
    let normalized = |idx: usize, extent: usize| -> T {
        if !align_corners {
            (two * T::from(idx).unwrap() + one) / T::from(extent).unwrap() - one
        } else if extent <= 1 {
            zero
        } else {
            two * T::from(idx).unwrap() / T::from(extent - 1).unwrap() - one
        }
    };

    let mut grid = vec![zero; batch * h * w * 2];
    for b in 0..batch {
        // Row-major 2x3 matrix of this batch element.
        let m = &theta_data[b * 6..b * 6 + 6];
        for iy in 0..h {
            let y_norm = normalized(iy, h);
            for ix in 0..w {
                let x_norm = normalized(ix, w);
                let out_base = ((b * h + iy) * w + ix) * 2;
                grid[out_base] = m[0] * x_norm + m[1] * y_norm + m[2];
                grid[out_base + 1] = m[3] * x_norm + m[4] * y_norm + m[5];
            }
        }
    }

    let out_shape = vec![batch, h, w, 2];
    let storage = TensorStorage::cpu(grid);
    if !(is_grad_enabled() && theta.requires_grad()) {
        return Tensor::from_storage(storage, out_shape, false)?.to(theta_device);
    }
    Tensor::from_operation(
        storage,
        out_shape,
        Arc::new(AffineGridBackward {
            theta: theta.clone(),
            size,
            align_corners,
        }),
    )?
    .to(theta_device)
}
#[derive(Debug)]
struct AffineGridBackward<T: Float> {
theta: Tensor<T>,
size: [usize; 4],
align_corners: bool,
}
impl<T: Float> GradFn<T> for AffineGridBackward<T> {
fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
if !self.theta.requires_grad() {
return Ok(vec![None]);
}
let batch = self.size[0];
let h = self.size[2];
let w = self.size[3];
let one = T::from(1.0).unwrap();
let two = T::from(2.0).unwrap();
let zero = T::from(0.0).unwrap();
let go_data = grad_output.data_vec()?;
let mut grad_theta = vec![zero; batch * 6];
for b in 0..batch {
for iy in 0..h {
let y_norm = if self.align_corners {
if h <= 1 {
zero
} else {
two * T::from(iy).unwrap() / T::from(h - 1).unwrap() - one
}
} else {
(two * T::from(iy).unwrap() + one) / T::from(h).unwrap() - one
};
for ix in 0..w {
let x_norm = if self.align_corners {
if w <= 1 {
zero
} else {
two * T::from(ix).unwrap() / T::from(w - 1).unwrap() - one
}
} else {
(two * T::from(ix).unwrap() + one) / T::from(w).unwrap() - one
};
let go_base = ((b * h + iy) * w + ix) * 2;
let gx = go_data[go_base];
let gy = go_data[go_base + 1];
let t_base = b * 6;
grad_theta[t_base] += gx * x_norm;
grad_theta[t_base + 1] += gx * y_norm;
grad_theta[t_base + 2] += gx;
grad_theta[t_base + 3] += gy * x_norm;
grad_theta[t_base + 4] += gy * y_norm;
grad_theta[t_base + 5] += gy;
}
}
}
let grad_tensor = Tensor::from_storage(
TensorStorage::cpu(grad_theta),
self.theta.shape().to_vec(),
false,
)?;
Ok(vec![Some(grad_tensor)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.theta]
}
fn name(&self) -> &'static str {
"AffineGridBackward"
}
}
/// Parameter-free module wrapping `pixel_shuffle`.
#[derive(Debug, Clone, Copy)]
pub struct PixelShuffle {
    /// Factor `r` by which height and width grow; channels shrink by `r * r`.
    pub upscale_factor: usize,
}

impl PixelShuffle {
    /// Creates the module. The factor is validated when `forward` runs.
    pub fn new(upscale_factor: usize) -> Self {
        PixelShuffle { upscale_factor }
    }
}
impl<T: Float> Module<T> for PixelShuffle {
    /// Applies `pixel_shuffle` with the stored upscale factor.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let factor = self.upscale_factor;
        pixel_shuffle(input, factor)
    }

    /// The module has no learnable parameters.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        Vec::new()
    }

    /// The module has no learnable parameters.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        Vec::new()
    }

    /// The module has no learnable parameters.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        Vec::new()
    }

    /// No-op: the module keeps no training state.
    fn train(&mut self) {}

    /// No-op: the module keeps no training state.
    fn eval(&mut self) {}

    /// Always reports `false`.
    fn is_training(&self) -> bool {
        false
    }
}
/// Parameter-free module wrapping `pixel_unshuffle`.
#[derive(Debug, Clone, Copy)]
pub struct PixelUnshuffle {
    /// Factor `r` by which height and width shrink; channels grow by `r * r`.
    pub downscale_factor: usize,
}

impl PixelUnshuffle {
    /// Creates the module. The factor is validated when `forward` runs.
    pub fn new(downscale_factor: usize) -> Self {
        PixelUnshuffle { downscale_factor }
    }
}
impl<T: Float> Module<T> for PixelUnshuffle {
    /// Applies `pixel_unshuffle` with the stored downscale factor.
    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
        let factor = self.downscale_factor;
        pixel_unshuffle(input, factor)
    }

    /// The module has no learnable parameters.
    fn parameters(&self) -> Vec<&Parameter<T>> {
        Vec::new()
    }

    /// The module has no learnable parameters.
    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
        Vec::new()
    }

    /// The module has no learnable parameters.
    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
        Vec::new()
    }

    /// No-op: the module keeps no training state.
    fn train(&mut self) {}

    /// No-op: the module keeps no training state.
    fn eval(&mut self) {}

    /// Always reports `false`.
    fn is_training(&self) -> bool {
        false
    }
}
/// Rearranges `[B, C * r^2, H, W]` into `[B, C, H * r, W * r]`.
///
/// Output pixel `(c, oh, ow)` is taken from input channel
/// `c * r^2 + (oh % r) * r + (ow % r)` at position `(oh / r, ow / r)`.
/// The result is moved to the input's device and, when gradients are enabled
/// and `input` requires grad, linked to a `PixelShuffleBackward` node.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] if `input` is not 4D, if
/// `upscale_factor` is zero, or if the channel count is not divisible by `r^2`.
pub fn pixel_shuffle<T: Float>(
    input: &Tensor<T>,
    upscale_factor: usize,
) -> FerrotorchResult<Tensor<T>> {
    let (batch, channels_in, h, w) = validate_4d(input, "pixel_shuffle")?;
    let r = upscale_factor;
    if r == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "pixel_shuffle: upscale_factor must be > 0".into(),
        });
    }
    let r_sq = r * r;
    if channels_in % r_sq != 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "pixel_shuffle: channels ({channels_in}) must be divisible by r^2 ({})",
                r_sq
            ),
        });
    }
    let c_out = channels_in / r_sq;
    let (h_out, w_out) = (h * r, w * r);

    let input_device = input.device();
    let data = input.data_vec()?;
    let mut output = vec![T::from(0.0).unwrap(); batch * c_out * h_out * w_out];

    // Walk the output in memory order and look up where each pixel comes from.
    let mut out_idx = 0;
    for b in 0..batch {
        for c in 0..c_out {
            for oh in 0..h_out {
                let (ih, rh) = (oh / r, oh % r);
                for ow in 0..w_out {
                    let (iw, rw) = (ow / r, ow % r);
                    let in_c = c * r_sq + rh * r + rw;
                    output[out_idx] = data[((b * channels_in + in_c) * h + ih) * w + iw];
                    out_idx += 1;
                }
            }
        }
    }

    let out_shape = vec![batch, c_out, h_out, w_out];
    let storage = TensorStorage::cpu(output);
    if !(is_grad_enabled() && input.requires_grad()) {
        return Tensor::from_storage(storage, out_shape, false)?.to(input_device);
    }
    Tensor::from_operation(
        storage,
        out_shape,
        Arc::new(PixelShuffleBackward {
            input: input.clone(),
            upscale_factor: r,
        }),
    )?
    .to(input_device)
}
/// Rearranges `[B, C, H, W]` into `[B, C * r^2, H / r, W / r]`; the inverse
/// layout of `pixel_shuffle`.
///
/// Input pixel `(c, ih, iw)` is written to output channel
/// `c * r^2 + (ih % r) * r + (iw % r)` at position `(ih / r, iw / r)`.
/// The result is moved to the input's device and, when gradients are enabled
/// and `input` requires grad, linked to a `PixelUnshuffleBackward` node.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] if `input` is not 4D, if
/// `downscale_factor` is zero, or if height or width is not divisible by it.
pub fn pixel_unshuffle<T: Float>(
    input: &Tensor<T>,
    downscale_factor: usize,
) -> FerrotorchResult<Tensor<T>> {
    let (batch, channels, h_in, w_in) = validate_4d(input, "pixel_unshuffle")?;
    let r = downscale_factor;
    if r == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "pixel_unshuffle: downscale_factor must be > 0".into(),
        });
    }
    let divisible = h_in % r == 0 && w_in % r == 0;
    if !divisible {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "pixel_unshuffle: spatial dims ({h_in}, {w_in}) must be divisible by r={r}"
            ),
        });
    }
    let (h_out, w_out) = (h_in / r, w_in / r);
    let c_out = channels * r * r;

    let input_device = input.device();
    let data = input.data_vec()?;
    let mut output = vec![T::from(0.0).unwrap(); batch * c_out * h_out * w_out];

    // Walk the input in memory order and scatter each pixel to its new channel.
    let mut in_idx = 0;
    for b in 0..batch {
        for c in 0..channels {
            for ih in 0..h_in {
                let (oh, rh) = (ih / r, ih % r);
                for iw in 0..w_in {
                    let (ow, rw) = (iw / r, iw % r);
                    let out_c = c * r * r + rh * r + rw;
                    output[((b * c_out + out_c) * h_out + oh) * w_out + ow] = data[in_idx];
                    in_idx += 1;
                }
            }
        }
    }

    let out_shape = vec![batch, c_out, h_out, w_out];
    let storage = TensorStorage::cpu(output);
    if !(is_grad_enabled() && input.requires_grad()) {
        return Tensor::from_storage(storage, out_shape, false)?.to(input_device);
    }
    Tensor::from_operation(
        storage,
        out_shape,
        Arc::new(PixelUnshuffleBackward {
            input: input.clone(),
            downscale_factor: r,
        }),
    )?
    .to(input_device)
}
/// Autograd node recorded by `pixel_shuffle`.
#[derive(Debug)]
struct PixelShuffleBackward<T: Float> {
    /// Forward input; consulted for its `requires_grad` flag.
    input: Tensor<T>,
    /// Upscale factor used by the forward pass.
    upscale_factor: usize,
}

impl<T: Float> GradFn<T> for PixelShuffleBackward<T> {
    /// The forward pass is a pure permutation, so the gradient is the inverse
    /// permutation (`pixel_unshuffle`) of `grad_output`.
    ///
    /// Returns `[None]` when the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.requires_grad() {
            pixel_unshuffle(grad_output, self.upscale_factor).map(|grad| vec![Some(grad)])
        } else {
            Ok(vec![None])
        }
    }

    /// The single differentiable input of the operation.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Node name used in autograd diagnostics.
    fn name(&self) -> &'static str {
        "PixelShuffleBackward"
    }
}
/// Autograd node recorded by `pixel_unshuffle`.
#[derive(Debug)]
struct PixelUnshuffleBackward<T: Float> {
    /// Forward input; consulted for its `requires_grad` flag.
    input: Tensor<T>,
    /// Downscale factor used by the forward pass.
    downscale_factor: usize,
}

impl<T: Float> GradFn<T> for PixelUnshuffleBackward<T> {
    /// The forward pass is a pure permutation, so the gradient is the inverse
    /// permutation (`pixel_shuffle`) of `grad_output`.
    ///
    /// Returns `[None]` when the input does not require grad.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.requires_grad() {
            pixel_shuffle(grad_output, self.downscale_factor).map(|grad| vec![Some(grad)])
        } else {
            Ok(vec![None])
        }
    }

    /// The single differentiable input of the operation.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.input]
    }

    /// Node name used in autograd diagnostics.
    fn name(&self) -> &'static str {
        "PixelUnshuffleBackward"
    }
}
/// Parameter-free module wrapping `unfold` (sliding-window patch extraction).
///
/// Every field is `[height, width]`.
#[derive(Debug, Clone, Copy)]
pub struct Unfold {
    /// Window size.
    pub kernel_size: [usize; 2],
    /// Spacing between taps inside a window.
    pub dilation: [usize; 2],
    /// Implicit zero padding added on both sides of each axis.
    pub padding: [usize; 2],
    /// Step between consecutive windows.
    pub stride: [usize; 2],
}

impl Unfold {
    /// Creates the module. Arguments are validated when `forward` runs.
    pub fn new(
        kernel_size: [usize; 2],
        dilation: [usize; 2],
        padding: [usize; 2],
        stride: [usize; 2],
    ) -> Self {
        Unfold {
            kernel_size,
            dilation,
            padding,
            stride,
        }
    }
}
impl<T: Float> Module<T> for Unfold {
fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
unfold(
input,
self.kernel_size,
self.dilation,
self.padding,
self.stride,
)
}
fn parameters(&self) -> Vec<&Parameter<T>> {
vec![]
}
fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
vec![]
}
fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
vec![]
}
fn train(&mut self) {}
fn eval(&mut self) {}
fn is_training(&self) -> bool {
false
}
}
/// Parameter-free module wrapping `fold`.
///
/// Every field is `[height, width]`.
#[derive(Debug, Clone, Copy)]
pub struct Fold {
    /// Spatial size of the folded output.
    pub output_size: [usize; 2],
    /// Window size.
    pub kernel_size: [usize; 2],
    /// Spacing between taps inside a window.
    pub dilation: [usize; 2],
    /// Implicit padding on both sides of each axis.
    pub padding: [usize; 2],
    /// Step between consecutive windows.
    pub stride: [usize; 2],
}

impl Fold {
    /// Creates the module. Arguments are passed to `fold` unchanged when
    /// `forward` runs.
    pub fn new(
        output_size: [usize; 2],
        kernel_size: [usize; 2],
        dilation: [usize; 2],
        padding: [usize; 2],
        stride: [usize; 2],
    ) -> Self {
        Fold {
            output_size,
            kernel_size,
            dilation,
            padding,
            stride,
        }
    }
}
impl<T: Float> Module<T> for Fold {
fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
fold(
input,
self.output_size,
self.kernel_size,
self.dilation,
self.padding,
self.stride,
)
}
fn parameters(&self) -> Vec<&Parameter<T>> {
vec![]
}
fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
vec![]
}
fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
vec![]
}
fn train(&mut self) {}
fn eval(&mut self) {}
fn is_training(&self) -> bool {
false
}
}
/// Number of sliding-window positions along one axis:
/// `(input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1`.
///
/// Returns `0` when no window fits, i.e. when the dilated kernel is larger
/// than the padded input, or when `kernel_size` or `stride` is zero. The
/// plain formula would underflow `usize` (or divide by zero) in those cases.
#[inline]
fn unfold_output_size(
    input_size: usize,
    kernel_size: usize,
    dilation: usize,
    padding: usize,
    stride: usize,
) -> usize {
    if kernel_size == 0 || stride == 0 {
        return 0;
    }
    let padded = input_size.saturating_add(padding.saturating_mul(2));
    // Extent covered by one dilated window.
    let span = dilation.saturating_mul(kernel_size - 1).saturating_add(1);
    if padded < span {
        return 0;
    }
    (padded - span) / stride + 1
}
/// Extracts sliding local blocks from a `[B, C, H, W]` tensor.
///
/// Returns a `[B, C * kh * kw, L]` tensor, where `L` is the number of window
/// positions (`out_h * out_w`). Row `(c * kh + i) * kw + j` holds tap
/// `(i, j)` of channel `c` for every window; taps that land in the implicit
/// padding are zero. All array arguments are `[height, width]`.
///
/// The result is moved to the input's device and, when gradients are enabled
/// and `input` requires grad, linked to an `UnfoldBackward` node.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] if `input` is not 4D, if any
/// of `kernel_size`, `stride` or `dilation` contains a zero, or if the
/// dilated kernel does not fit inside the padded input on either axis.
pub fn unfold<T: Float>(
    input: &Tensor<T>,
    kernel_size: [usize; 2],
    dilation: [usize; 2],
    padding: [usize; 2],
    stride: [usize; 2],
) -> FerrotorchResult<Tensor<T>> {
    let (batch, channels, h, w) = validate_4d(input, "unfold")?;
    if kernel_size[0] == 0
        || kernel_size[1] == 0
        || stride[0] == 0
        || stride[1] == 0
        || dilation[0] == 0
        || dilation[1] == 0
    {
        return Err(FerrotorchError::InvalidArgument {
            message: "unfold: kernel_size, stride, dilation must all be > 0".into(),
        });
    }

    // The window-count formula subtracts the dilated kernel extent from the
    // padded input size in `usize`; reject inputs where that would underflow.
    let window_fits = |extent: usize, axis: usize| -> bool {
        let span = dilation[axis]
            .saturating_mul(kernel_size[axis] - 1)
            .saturating_add(1);
        extent.saturating_add(padding[axis].saturating_mul(2)) >= span
    };
    if !window_fits(h, 0) || !window_fits(w, 1) {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "unfold: dilated kernel (kernel_size {kernel_size:?}, dilation {dilation:?}) does not fit input ({h}, {w}) with padding {padding:?}"
            ),
        });
    }

    let out_h = unfold_output_size(h, kernel_size[0], dilation[0], padding[0], stride[0]);
    let out_w = unfold_output_size(w, kernel_size[1], dilation[1], padding[1], stride[1]);
    let l = out_h * out_w;
    let k = channels * kernel_size[0] * kernel_size[1];

    let input_device = input.device();
    let data = input.data_vec()?;
    // Zero-initialised so taps that fall into the padding need no write.
    let mut output = vec![T::from(0.0).unwrap(); batch * k * l];

    for b in 0..batch {
        for c in 0..channels {
            for kh in 0..kernel_size[0] {
                for kw in 0..kernel_size[1] {
                    let k_idx = (c * kernel_size[0] + kh) * kernel_size[1] + kw;
                    for oh in 0..out_h {
                        let ih = (oh * stride[0] + kh * dilation[0]) as isize
                            - padding[0] as isize;
                        if ih < 0 || ih >= h as isize {
                            // Whole row of taps lies in the padding.
                            continue;
                        }
                        let src_row = ((b * channels + c) * h + ih as usize) * w;
                        let dst_row = (b * k + k_idx) * l + oh * out_w;
                        for ow in 0..out_w {
                            let iw = (ow * stride[1] + kw * dilation[1]) as isize
                                - padding[1] as isize;
                            if iw >= 0 && iw < w as isize {
                                output[dst_row + ow] = data[src_row + iw as usize];
                            }
                        }
                    }
                }
            }
        }
    }

    let out_shape = vec![batch, k, l];
    let storage = TensorStorage::cpu(output);
    if is_grad_enabled() && input.requires_grad() {
        Tensor::from_operation(
            storage,
            out_shape,
            Arc::new(UnfoldBackward {
                input: input.clone(),
                kernel_size,
                dilation,
                padding,
                stride,
            }),
        )?
        .to(input_device)
    } else {
        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
    }
}
/// Combines an array of sliding local blocks into a 4-D tensor (col2im).
///
/// `input` has shape `[B, C * kH * kW, L]`, laid out the way `unfold` produces
/// it. The result has shape `[B, C, H, W]` with `[H, W] = output_size`.
/// Values from overlapping windows are summed; window elements that map into
/// the padding region are discarded.
///
/// The output is assembled in CPU storage and then moved to the device of
/// `input`. If gradient recording is enabled and `input` requires a gradient,
/// the result is recorded with a `FoldBackward` node.
///
/// # Errors
///
/// Returns `FerrotorchError::InvalidArgument` when
/// - `input` is not 3-D,
/// - any entry of `kernel_size`, `stride` or `dilation` is zero,
/// - dimension 1 of `input` is not divisible by `kH * kW`,
/// - the dilated kernel does not fit inside the padded `output_size` along
///   either spatial axis, or
/// - `L` does not equal the number of window positions implied by
///   `output_size` and the window settings.
///
/// Errors from reading the input data, constructing the output tensor or
/// moving it to the input device are propagated unchanged.
pub fn fold<T: Float>(
    input: &Tensor<T>,
    output_size: [usize; 2],
    kernel_size: [usize; 2],
    dilation: [usize; 2],
    padding: [usize; 2],
    stride: [usize; 2],
) -> FerrotorchResult<Tensor<T>> {
    let shape = input.shape();
    if shape.len() != 3 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "fold expects 3D input [B, C*kH*kW, L], got shape {:?}",
                shape
            ),
        });
    }
    if kernel_size[0] == 0
        || kernel_size[1] == 0
        || stride[0] == 0
        || stride[1] == 0
        || dilation[0] == 0
        || dilation[1] == 0
    {
        return Err(FerrotorchError::InvalidArgument {
            message: "fold: kernel_size, stride, dilation must all be > 0".into(),
        });
    }
    let batch = shape[0];
    let k = shape[1];
    let l = shape[2];
    let [h_out, w_out] = output_size;
    let k_area = kernel_size[0] * kernel_size[1];
    if k % k_area != 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("fold: dim 1 ({k}) must be divisible by kH*kW ({})", k_area),
        });
    }
    let channels = k / k_area;

    // The window-count formula subtracts the dilated kernel extent from the
    // padded output size in unsigned arithmetic; reject configurations where
    // that difference would be negative instead of underflowing.
    for (axis, &extent) in output_size.iter().enumerate() {
        let padded = extent + 2 * padding[axis];
        let span = dilation[axis] * (kernel_size[axis] - 1) + 1;
        if padded < span {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "fold: dilated kernel extent {span} exceeds padded output size {padded} along spatial axis {axis}"
                ),
            });
        }
    }

    let expected_out_h =
        unfold_output_size(h_out, kernel_size[0], dilation[0], padding[0], stride[0]);
    let expected_out_w =
        unfold_output_size(w_out, kernel_size[1], dilation[1], padding[1], stride[1]);
    let expected_l = expected_out_h * expected_out_w;
    if l != expected_l {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "fold: L={l} does not match expected {expected_l} for output_size ({h_out}, {w_out})"
            ),
        });
    }
    let input_device = input.device();
    let data = input.data_vec()?;
    let total = batch * channels * h_out * w_out;
    let mut output = vec![T::from(0.0).unwrap(); total];
    for b in 0..batch {
        for c in 0..channels {
            for kh in 0..kernel_size[0] {
                for kw in 0..kernel_size[1] {
                    // Row of the input matrix for this (channel, kernel offset).
                    let k_idx = (c * kernel_size[0] + kh) * kernel_size[1] + kw;
                    for oh in 0..expected_out_h {
                        for ow in 0..expected_out_w {
                            // Coordinates in the padded output, then shifted back
                            // into (possibly negative) unpadded coordinates.
                            let ih = oh * stride[0] + kh * dilation[0];
                            let iw = ow * stride[1] + kw * dilation[1];
                            let ih = ih as isize - padding[0] as isize;
                            let iw = iw as isize - padding[1] as isize;
                            if ih >= 0 && ih < h_out as isize && iw >= 0 && iw < w_out as isize {
                                let l_idx = oh * expected_out_w + ow;
                                let in_idx = (b * k + k_idx) * l + l_idx;
                                let out_idx = ((b * channels + c) * h_out + ih as usize) * w_out
                                    + iw as usize;
                                // Accumulate: several windows may cover the same pixel.
                                output[out_idx] += data[in_idx];
                            }
                        }
                    }
                }
            }
        }
    }
    let out_shape = vec![batch, channels, h_out, w_out];
    let storage = TensorStorage::cpu(output);
    if is_grad_enabled() && input.requires_grad() {
        Tensor::from_operation(
            storage,
            out_shape,
            Arc::new(FoldBackward {
                input: input.clone(),
                kernel_size,
                dilation,
                padding,
                stride,
            }),
        )?
        .to(input_device)
    } else {
        Tensor::from_storage(storage, out_shape, false)?.to(input_device)
    }
}
/// Autograd node recorded by `unfold`.
///
/// Keeps the forward input and the window settings so the backward pass can
/// call `fold` with the input's spatial size.
#[derive(Debug)]
struct UnfoldBackward<T: Float> {
/// The tensor that was unfolded in the forward pass.
input: Tensor<T>,
/// Window size `[kH, kW]` used in the forward pass.
kernel_size: [usize; 2],
/// Dilation used in the forward pass.
dilation: [usize; 2],
/// Padding used in the forward pass.
padding: [usize; 2],
/// Stride used in the forward pass.
stride: [usize; 2],
}
impl<T: Float> GradFn<T> for UnfoldBackward<T> {
    /// Propagates `grad_output` back to the unfolded input.
    ///
    /// The gradient is obtained by folding `grad_output` to the spatial size
    /// of the forward input (dimensions 2 and 3 of its shape) with the same
    /// window settings. Yields `[None]` when the input does not require a
    /// gradient; errors from `fold` are returned unchanged.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.requires_grad() {
            let dims = self.input.shape();
            let spatial = [dims[2], dims[3]];
            fold(
                grad_output,
                spatial,
                self.kernel_size,
                self.dilation,
                self.padding,
                self.stride,
            )
            .map(|grad_input| vec![Some(grad_input)])
        } else {
            Ok(vec![None])
        }
    }

    /// The single input of this node: the tensor that was unfolded.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        Vec::from([&self.input])
    }

    /// Name reported for this node.
    fn name(&self) -> &'static str {
        "UnfoldBackward"
    }
}
/// Autograd node recorded by `fold`.
///
/// Keeps the forward input and the window settings so the backward pass can
/// call `unfold` on the incoming gradient.
#[derive(Debug)]
struct FoldBackward<T: Float> {
/// The tensor that was folded in the forward pass.
input: Tensor<T>,
/// Window size `[kH, kW]` used in the forward pass.
kernel_size: [usize; 2],
/// Dilation used in the forward pass.
dilation: [usize; 2],
/// Padding used in the forward pass.
padding: [usize; 2],
/// Stride used in the forward pass.
stride: [usize; 2],
}
impl<T: Float> GradFn<T> for FoldBackward<T> {
    /// Propagates `grad_output` back to the folded input.
    ///
    /// The gradient is obtained by unfolding `grad_output` with the same
    /// window settings as the forward pass. Yields `[None]` when the input
    /// does not require a gradient; errors from `unfold` are returned unchanged.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.requires_grad() {
            unfold(
                grad_output,
                self.kernel_size,
                self.dilation,
                self.padding,
                self.stride,
            )
            .map(|grad_input| vec![Some(grad_input)])
        } else {
            Ok(vec![None])
        }
    }

    /// The single input of this node: the tensor that was folded.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        Vec::from([&self.input])
    }

    /// Name reported for this node.
    fn name(&self) -> &'static str {
        "FoldBackward"
    }
}
// Unit tests for the resampling, pixel-shuffle, unfold/fold, grid-sample and
// affine-grid operations and for their module wrappers.
#[cfg(test)]
mod tests {
use super::*;
// Builds a 4-D f32 tensor on CPU storage from `data` with the given shape.
fn leaf_4d(data: &[f32], shape: [usize; 4], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
shape.to_vec(),
requires_grad,
)
.unwrap()
}
// Same as `leaf_4d`, but accepts a shape of any rank.
fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
Tensor::from_storage(
TensorStorage::cpu(data.to_vec()),
shape.to_vec(),
requires_grad,
)
.unwrap()
}
// Asserts equal length and element-wise absolute difference strictly below `tol`.
fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
assert_eq!(
actual.len(),
expected.len(),
"length mismatch: {} vs {}",
actual.len(),
expected.len()
);
for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
assert!(
(a - e).abs() < tol,
"index {i}: actual={a} expected={e} diff={}",
(a - e).abs(),
);
}
}
// ---- interpolate: forward values and shapes ----
// Nearest 2x upsampling: every input pixel is replicated into a 2x2 block.
#[test]
fn test_interpolate_nearest_upsample_2x() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let input = leaf_4d(&data, [1, 1, 2, 2], false);
let out = interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).unwrap();
assert_eq!(out.shape(), &[1, 1, 4, 4]);
let d = out.data().unwrap();
#[rustfmt::skip]
let expected: Vec<f32> = vec![
1.0, 1.0, 2.0, 2.0,
1.0, 1.0, 2.0, 2.0,
3.0, 3.0, 4.0, 4.0,
3.0, 3.0, 4.0, 4.0,
];
assert_close(d, &expected, 1e-6);
}
// Nearest 4x4 -> 2x2: the expected values are the top-left pixel of each 2x2 block.
#[test]
fn test_interpolate_nearest_downsample() {
#[rustfmt::skip]
let data: Vec<f32> = vec![
1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0,
13.0, 14.0, 15.0, 16.0,
];
let input = leaf_4d(&data, [1, 1, 4, 4], false);
let out = interpolate(&input, Some([2, 2]), None, InterpolateMode::Nearest, false).unwrap();
assert_eq!(out.shape(), &[1, 1, 2, 2]);
let d = out.data().unwrap();
assert_close(d, &[1.0, 3.0, 9.0, 11.0], 1e-6);
}
// A scale factor of 2 per axis doubles both spatial dimensions (shape check only).
#[test]
fn test_interpolate_nearest_scale_factor() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let input = leaf_4d(&data, [1, 1, 2, 2], false);
let out = interpolate(
&input,
None,
Some([2.0, 2.0]),
InterpolateMode::Nearest,
false,
)
.unwrap();
assert_eq!(out.shape(), &[1, 1, 4, 4]);
}
// Bilinear with align_corners=true, 2x2 -> 3x3: the four output corners
// (flat indices 0, 2, 6, 8) equal the input corners and the centre (index 4) is 1.5.
#[test]
fn test_interpolate_bilinear_upsample() {
let data: Vec<f32> = vec![0.0, 1.0, 2.0, 3.0];
let input = leaf_4d(&data, [1, 1, 2, 2], false);
let out = interpolate(&input, Some([3, 3]), None, InterpolateMode::Bilinear, true).unwrap();
assert_eq!(out.shape(), &[1, 1, 3, 3]);
let d = out.data().unwrap();
assert!((d[0] - 0.0).abs() < 1e-5); assert!((d[2] - 1.0).abs() < 1e-5); assert!((d[6] - 2.0).abs() < 1e-5); assert!((d[8] - 3.0).abs() < 1e-5); assert!((d[4] - 1.5).abs() < 1e-5);
}
// Resizing to the same size with align_corners=true must return the input values.
#[test]
fn test_interpolate_bilinear_identity() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let input = leaf_4d(&data, [1, 1, 3, 3], false);
let out = interpolate(&input, Some([3, 3]), None, InterpolateMode::Bilinear, true).unwrap();
assert_eq!(out.shape(), &[1, 1, 3, 3]);
assert_close(out.data().unwrap(), &data, 1e-5);
}
// Bicubic 8x8 -> 16x16: shape check only.
#[test]
fn test_interpolate_bicubic_output_shape() {
let data: Vec<f32> = vec![0.0; 64];
let input = leaf_4d(&data, [1, 1, 8, 8], false);
let out = interpolate(
&input,
Some([16, 16]),
None,
InterpolateMode::Bicubic,
false,
)
.unwrap();
assert_eq!(out.shape(), &[1, 1, 16, 16]);
}
// Bicubic with align_corners=true, 2x2 -> 5x5: the four output corners
// (flat indices 0, 4, 20, 24) equal the input corners.
#[test]
fn test_interpolate_bicubic_corners_align() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let input = leaf_4d(&data, [1, 1, 2, 2], false);
let out = interpolate(&input, Some([5, 5]), None, InterpolateMode::Bicubic, true).unwrap();
assert_eq!(out.shape(), &[1, 1, 5, 5]);
let d = out.data().unwrap();
assert!((d[0] - 1.0).abs() < 1e-4); assert!((d[4] - 2.0).abs() < 1e-4); assert!((d[20] - 3.0).abs() < 1e-4); assert!((d[24] - 4.0).abs() < 1e-4); }
// ---- Upsample module ----
// Upsample built from an explicit target size produces that size.
#[test]
fn test_upsample_module_nearest() {
let up = Upsample::new([6, 6], InterpolateMode::Nearest);
let input = leaf_4d(&[0.0; 9], [1, 1, 3, 3], false);
let out: Tensor<f32> = Module::<f32>::forward(&up, &input).unwrap();
assert_eq!(out.shape(), &[1, 1, 6, 6]);
}
// Upsample built from a scale factor of 2 doubles both spatial dimensions.
#[test]
fn test_upsample_module_bilinear_scale() {
let up = Upsample::with_scale_factor([2.0, 2.0], InterpolateMode::Bilinear);
let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
let out: Tensor<f32> = Module::<f32>::forward(&up, &input).unwrap();
assert_eq!(out.shape(), &[1, 1, 4, 4]);
}
// Upsample exposes no parameters.
#[test]
fn test_upsample_no_parameters() {
let up = Upsample::new([4, 4], InterpolateMode::Nearest);
assert!(Module::<f32>::parameters(&up).is_empty());
}
// ---- interpolate: argument validation ----
// Neither `size` nor `scale_factor` given: error.
#[test]
fn test_interpolate_no_size_no_scale() {
let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
assert!(interpolate(&input, None, None, InterpolateMode::Nearest, false).is_err());
}
// Both `size` and `scale_factor` given: error.
#[test]
fn test_interpolate_both_size_and_scale() {
let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
assert!(
interpolate(
&input,
Some([4, 4]),
Some([2.0, 2.0]),
InterpolateMode::Nearest,
false
)
.is_err()
);
}
// Nearest mode combined with align_corners=true: error.
#[test]
fn test_interpolate_nearest_align_corners_rejected() {
let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
assert!(interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, true).is_err());
}
// Non-4-D input is rejected. NOTE: despite the test name, the tensor built
// here has the 2-D shape [2, 3].
#[test]
fn test_interpolate_3d_rejected() {
let input = leaf(&[0.0; 6], &[2, 3], false);
assert!(interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).is_err());
}
// ---- interpolate: backward ----
// The loss is a hand-built scalar node whose backward (TestSumBackward) feeds
// a tensor of ones into `out`. The test expects a gradient of 4.0 for every
// input pixel of the nearest 2x upsample.
#[test]
fn test_interpolate_nearest_backward() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let input = leaf_4d(&data, [1, 1, 2, 2], true);
let out = interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).unwrap();
let out_data = out.data().unwrap().to_vec();
let total: f32 = out_data.iter().sum();
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(TestSumBackward { input: out }),
)
.unwrap();
loss.backward().unwrap();
let grad = input.grad().unwrap().unwrap();
let g = grad.data().unwrap();
for (i, &val) in g.iter().enumerate() {
assert!(
(val - 4.0).abs() < 1e-5,
"grad[{i}]: expected 4.0, got {val}"
);
}
}
// Bilinear 2x2 -> 4x4 with an all-ones upstream gradient: the input gradients
// must sum to 16.0, the number of output elements.
#[test]
fn test_interpolate_bilinear_backward() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let input = leaf_4d(&data, [1, 1, 2, 2], true);
let out =
interpolate(&input, Some([4, 4]), None, InterpolateMode::Bilinear, false).unwrap();
let out_data = out.data().unwrap().to_vec();
let total: f32 = out_data.iter().sum();
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(TestSumBackward { input: out }),
)
.unwrap();
loss.backward().unwrap();
let grad = input.grad().unwrap().unwrap();
let g = grad.data().unwrap();
let grad_sum: f32 = g.iter().sum();
assert!(
(grad_sum - 16.0).abs() < 1e-3,
"grad sum = {grad_sum}, expected 16.0"
);
}
// ---- pixel_shuffle / pixel_unshuffle ----
// [1, 4, 2, 2] with factor 2 becomes [1, 1, 4, 4].
#[test]
fn test_pixel_shuffle_shape() {
let data = vec![0.0f32; 16];
let input = leaf_4d(&data, [1, 4, 2, 2], false);
let out = pixel_shuffle(&input, 2).unwrap();
assert_eq!(out.shape(), &[1, 1, 4, 4]);
}
// Four channels of a single pixel become one 2x2 block, in channel order.
#[test]
fn test_pixel_shuffle_values() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let input = leaf_4d(&data, [1, 4, 1, 1], false);
let out = pixel_shuffle(&input, 2).unwrap();
assert_eq!(out.shape(), &[1, 1, 2, 2]);
assert_close(out.data().unwrap(), &[1.0, 2.0, 3.0, 4.0], 1e-6);
}
// 3 channels with factor 2 (3 is not divisible by 4): error.
#[test]
fn test_pixel_shuffle_not_divisible() {
let input = leaf_4d(&[0.0; 12], [1, 3, 2, 2], false);
assert!(pixel_shuffle(&input, 2).is_err());
}
// [1, 1, 4, 4] with factor 2 becomes [1, 4, 2, 2].
#[test]
fn test_pixel_unshuffle_shape() {
let data = vec![0.0f32; 16];
let input = leaf_4d(&data, [1, 1, 4, 4], false);
let out = pixel_unshuffle(&input, 2).unwrap();
assert_eq!(out.shape(), &[1, 4, 2, 2]);
}
// pixel_unshuffle applied after pixel_shuffle restores the original data.
#[test]
fn test_pixel_shuffle_unshuffle_roundtrip() {
let data: Vec<f32> = (0..36).map(|i| i as f32).collect();
let input = leaf_4d(&data, [1, 4, 3, 3], false);
let shuffled = pixel_shuffle(&input, 2).unwrap();
assert_eq!(shuffled.shape(), &[1, 1, 6, 6]);
let roundtrip = pixel_unshuffle(&shuffled, 2).unwrap();
assert_eq!(roundtrip.shape(), &[1, 4, 3, 3]);
assert_close(roundtrip.data().unwrap(), &data, 1e-6);
}
// 3x3 spatial size with factor 2: error.
#[test]
fn test_pixel_unshuffle_spatial_not_divisible() {
let input = leaf_4d(&[0.0; 9], [1, 1, 3, 3], false);
assert!(pixel_unshuffle(&input, 2).is_err());
}
// With an all-ones upstream gradient, every input element must receive 1.0.
#[test]
fn test_pixel_shuffle_backward() {
let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
let input = leaf_4d(&data, [1, 4, 2, 2], true);
let out = pixel_shuffle(&input, 2).unwrap();
let out_data = out.data().unwrap().to_vec();
let total: f32 = out_data.iter().sum();
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(TestSumBackward { input: out }),
)
.unwrap();
loss.backward().unwrap();
let grad = input.grad().unwrap().unwrap();
let g = grad.data().unwrap();
for (i, &val) in g.iter().enumerate() {
assert!(
(val - 1.0).abs() < 1e-5,
"grad[{i}]: expected 1.0, got {val}"
);
}
}
// ---- unfold / fold ----
// 4x4 input, 2x2 kernel, stride 1, no padding: 3x3 = 9 windows of 4 values each.
#[test]
fn test_unfold_shape() {
let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
let out = unfold(&input, [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
assert_eq!(out.shape(), &[1, 4, 9]);
}
// Each output row corresponds to one kernel offset and lists that element
// across the four window positions (row-major over windows).
#[test]
fn test_unfold_values() {
#[rustfmt::skip]
let data: Vec<f32> = vec![
1.0, 2.0, 3.0,
4.0, 5.0, 6.0,
7.0, 8.0, 9.0,
];
let input = leaf_4d(&data, [1, 1, 3, 3], false);
let out = unfold(&input, [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
assert_eq!(out.shape(), &[1, 4, 4]);
let d = out.data().unwrap();
assert_close(&d[0..4], &[1.0, 2.0, 4.0, 5.0], 1e-6);
assert_close(&d[4..8], &[2.0, 3.0, 5.0, 6.0], 1e-6);
assert_close(&d[8..12], &[4.0, 5.0, 7.0, 8.0], 1e-6);
assert_close(&d[12..16], &[5.0, 6.0, 8.0, 9.0], 1e-6);
}
// 2x2 input padded by 1 on each side, 2x2 kernel, stride 1: 3x3 = 9 windows.
#[test]
fn test_unfold_with_padding() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let input = leaf_4d(&data, [1, 1, 2, 2], false);
let out = unfold(&input, [2, 2], [1, 1], [1, 1], [1, 1]).unwrap();
assert_eq!(out.shape(), &[1, 4, 9]);
}
// 4x4 input, 2x2 kernel, stride 2: 2x2 = 4 windows.
#[test]
fn test_unfold_with_stride() {
let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
let out = unfold(&input, [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
assert_eq!(out.shape(), &[1, 4, 4]);
}
// A zero kernel dimension is rejected.
#[test]
fn test_unfold_zero_kernel_rejected() {
let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
assert!(unfold(&input, [0, 2], [1, 1], [0, 0], [1, 1]).is_err());
}
// [1, 4, 9] folded with a 2x2 kernel and stride 1 gives a [1, 1, 4, 4] tensor.
#[test]
fn test_fold_shape() {
let data = vec![0.0f32; 36];
let input = leaf(&data, &[1, 4, 9], false);
let out = fold(&input, [4, 4], [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
assert_eq!(out.shape(), &[1, 1, 4, 4]);
}
// Stride equals kernel size, so windows do not overlap and fold(unfold(x)) == x.
#[test]
fn test_unfold_fold_roundtrip() {
#[rustfmt::skip]
let data: Vec<f32> = vec![
1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0,
9.0, 10.0, 11.0, 12.0,
13.0, 14.0, 15.0, 16.0,
];
let input = leaf_4d(&data, [1, 1, 4, 4], false);
let unfolded = unfold(&input, [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
let refolded = fold(&unfolded, [4, 4], [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
assert_eq!(refolded.shape(), &[1, 1, 4, 4]);
assert_close(refolded.data().unwrap(), &data, 1e-6);
}
// L = 5 does not match the 9 windows implied by a 4x4 output with a 2x2 kernel: error.
#[test]
fn test_fold_l_mismatch() {
let data = vec![0.0f32; 20];
let input = leaf(&data, &[1, 4, 5], false);
assert!(fold(&input, [4, 4], [2, 2], [1, 1], [0, 0], [1, 1]).is_err());
}
// ---- grid_sample ----
// A grid whose four points are the corner coordinates (-1/+1), sampled
// bilinearly with align_corners=true, must reproduce the 2x2 input.
#[test]
fn test_grid_sample_identity() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let input = leaf_4d(&data, [1, 1, 2, 2], false);
let grid_data: Vec<f32> = vec![-1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0];
let grid = leaf(&grid_data, &[1, 2, 2, 2], false);
let out = grid_sample(
&input,
&grid,
GridSampleMode::Bilinear,
GridSamplePaddingMode::Zeros,
true,
)
.unwrap();
assert_eq!(out.shape(), &[1, 1, 2, 2]);
assert_close(out.data().unwrap(), &data, 1e-5);
}
// A single grid point at (-1, -1) in nearest mode must return the first pixel (1.0).
#[test]
fn test_grid_sample_nearest() {
let data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
let input = leaf_4d(&data, [1, 1, 2, 2], false);
let grid_data: Vec<f32> = vec![-1.0, -1.0];
let grid = leaf(&grid_data, &[1, 1, 1, 2], false);
let out = grid_sample(
&input,
&grid,
GridSampleMode::Nearest,
GridSamplePaddingMode::Zeros,
true,
)
.unwrap();
assert_eq!(out.shape(), &[1, 1, 1, 1]);
assert!((out.data().unwrap()[0] - 1.0).abs() < 1e-5);
}
// Input batch size 2 with grid batch size 1: error.
#[test]
fn test_grid_sample_batch_mismatch() {
let input = leaf_4d(&[0.0; 8], [2, 1, 2, 2], false);
let grid = leaf(&[0.0; 8], &[1, 2, 2, 2], false);
assert!(
grid_sample(
&input,
&grid,
GridSampleMode::Bilinear,
GridSamplePaddingMode::Zeros,
true
)
.is_err()
);
}
// A 3-D grid is rejected.
#[test]
fn test_grid_sample_wrong_grid_shape() {
let input = leaf_4d(&[0.0; 4], [1, 1, 2, 2], false);
let grid = leaf(&[0.0; 8], &[1, 2, 4], false);
assert!(
grid_sample(
&input,
&grid,
GridSampleMode::Bilinear,
GridSamplePaddingMode::Zeros,
true
)
.is_err()
);
}
// ---- affine_grid ----
// Identity theta with align_corners=true on a 3x3 output: the first grid point
// is (-1, -1) and the third point (flat indices 4 and 5) is (1, -1).
#[test]
fn test_affine_grid_identity() {
let theta_data: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0];
let theta = leaf(&theta_data, &[1, 2, 3], false);
let grid = affine_grid(&theta, [1, 1, 3, 3], true).unwrap();
assert_eq!(grid.shape(), &[1, 3, 3, 2]);
let d = grid.data().unwrap();
assert!((d[0] - (-1.0)).abs() < 1e-5); assert!((d[1] - (-1.0)).abs() < 1e-5); assert!((d[4] - 1.0).abs() < 1e-5); assert!((d[5] - (-1.0)).abs() < 1e-5); }
// Theta of shape [2, 3, 2] is rejected.
#[test]
fn test_affine_grid_theta_shape_error() {
let theta = leaf(&[0.0; 12], &[2, 3, 2], false);
assert!(affine_grid(&theta, [2, 1, 3, 3], true).is_err());
}
// Theta batch size 1 with a requested output batch size of 2: error.
#[test]
fn test_affine_grid_batch_mismatch() {
let theta = leaf(&[0.0; 6], &[1, 2, 3], false);
assert!(affine_grid(&theta, [2, 1, 3, 3], true).is_err());
}
// ---- module wrappers (shape checks through the Module trait) ----
#[test]
fn test_pixel_shuffle_module() {
let ps = PixelShuffle::new(2);
let input = leaf_4d(&[0.0; 16], [1, 4, 2, 2], false);
let out: Tensor<f32> = Module::<f32>::forward(&ps, &input).unwrap();
assert_eq!(out.shape(), &[1, 1, 4, 4]);
}
#[test]
fn test_pixel_unshuffle_module() {
let pus = PixelUnshuffle::new(2);
let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
let out: Tensor<f32> = Module::<f32>::forward(&pus, &input).unwrap();
assert_eq!(out.shape(), &[1, 4, 2, 2]);
}
#[test]
fn test_unfold_module() {
let uf = Unfold::new([2, 2], [1, 1], [0, 0], [1, 1]);
let input = leaf_4d(&[0.0; 16], [1, 1, 4, 4], false);
let out: Tensor<f32> = Module::<f32>::forward(&uf, &input).unwrap();
assert_eq!(out.shape(), &[1, 4, 9]);
}
#[test]
fn test_fold_module() {
let f = Fold::new([4, 4], [2, 2], [1, 1], [0, 0], [1, 1]);
let data = vec![0.0f32; 36];
let input = leaf(&data, &[1, 4, 9], false);
let out: Tensor<f32> = Module::<f32>::forward(&f, &input).unwrap();
assert_eq!(out.shape(), &[1, 1, 4, 4]);
}
// ---- unfold: backward ----
// Non-overlapping windows (stride == kernel): every input pixel appears in
// exactly one window, so an all-ones upstream gradient yields 1.0 everywhere.
#[test]
fn test_unfold_backward() {
let data: Vec<f32> = (0..16).map(|i| i as f32).collect();
let input = leaf_4d(&data, [1, 1, 4, 4], true);
let out = unfold(&input, [2, 2], [1, 1], [0, 0], [2, 2]).unwrap();
let out_data = out.data().unwrap().to_vec();
let total: f32 = out_data.iter().sum();
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(TestSumBackward { input: out }),
)
.unwrap();
loss.backward().unwrap();
let grad = input.grad().unwrap().unwrap();
let g = grad.data().unwrap();
for (i, &val) in g.iter().enumerate() {
assert!(
(val - 1.0).abs() < 1e-5,
"grad[{i}]: expected 1.0, got {val}"
);
}
}
// Overlapping windows (stride 1): the expected gradient of each pixel is the
// number of 2x2 windows covering it (corners 1, edges 2, centre 4).
#[test]
fn test_unfold_backward_overlapping() {
let data: Vec<f32> = (0..9).map(|i| i as f32).collect();
let input = leaf_4d(&data, [1, 1, 3, 3], true);
let out = unfold(&input, [2, 2], [1, 1], [0, 0], [1, 1]).unwrap();
let out_data = out.data().unwrap().to_vec();
let total: f32 = out_data.iter().sum();
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![total]),
vec![],
Arc::new(TestSumBackward { input: out }),
)
.unwrap();
loss.backward().unwrap();
let grad = input.grad().unwrap().unwrap();
let g = grad.data().unwrap();
#[rustfmt::skip]
let expected: Vec<f32> = vec![
1.0, 2.0, 1.0,
2.0, 4.0, 2.0,
1.0, 2.0, 1.0,
];
assert_close(g, &expected, 1e-5);
}
// Batch size 2 and 3 channels are carried through unchanged (shape check only).
#[test]
fn test_interpolate_multichannel_batch() {
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let input = leaf_4d(&data, [2, 3, 2, 2], false);
let out = interpolate(&input, Some([4, 4]), None, InterpolateMode::Nearest, false).unwrap();
assert_eq!(out.shape(), &[2, 3, 4, 4]);
}
// Test-only autograd node standing in for a sum reduction: its backward
// ignores the incoming gradient and returns a tensor of ones shaped like `input`.
#[derive(Debug)]
struct TestSumBackward {
input: Tensor<f32>,
}
impl GradFn<f32> for TestSumBackward {
fn backward(
&self,
_grad_output: &Tensor<f32>,
) -> FerrotorchResult<Vec<Option<Tensor<f32>>>> {
let ones_data = vec![1.0f32; self.input.numel()];
let ones = Tensor::from_storage(
TensorStorage::cpu(ones_data),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(ones)])
}
fn inputs(&self) -> Vec<&Tensor<f32>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"TestSumBackward"
}
}
}