use alloc::vec;
use alloc::vec::Vec;
use burn_backend::element::cast::ToElement;
use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
use burn_backend::{DType, Element};
use burn_std::{Bytes, Shape, bf16, f16};
use num_traits::{Float, NumCast};
use crate::{FlexTensor, Layout};
/// Samples `tensor` at the normalized locations stored in `grid` (2-D spatial case).
///
/// Both inputs are made contiguous, then the work is forwarded to a kernel specialized on
/// `tensor`'s float dtype (F32, F64, F16 or BF16). Shape validation happens in that kernel.
///
/// # Panics
/// Panics when `options.mode` is neither `Bilinear` nor `Nearest`, and when `tensor` has a
/// dtype other than the four float types listed above.
pub fn grid_sample_2d(
    tensor: FlexTensor,
    grid: FlexTensor,
    options: GridSampleOptions,
) -> FlexTensor {
    let supported = matches!(
        options.mode,
        InterpolateMode::Bilinear | InterpolateMode::Nearest
    );
    if !supported {
        panic!("grid_sample_2d: {:?} mode is not supported", options.mode);
    }

    // The kernel indexes raw storage with hand-computed strides, so it needs dense layouts.
    let (tensor, grid) = (tensor.to_contiguous(), grid.to_contiguous());

    match tensor.dtype() {
        DType::F64 => grid_sample_2d_impl::<f64>(tensor, grid, options),
        DType::F32 => grid_sample_2d_impl::<f32>(tensor, grid, options),
        DType::BF16 => grid_sample_2d_impl::<bf16>(tensor, grid, options),
        DType::F16 => grid_sample_2d_impl::<f16>(tensor, grid, options),
        unsupported => panic!("grid_sample_2d: unsupported dtype {:?}", unsupported),
    }
}
/// Dtype-specialized grid-sampling kernel behind `grid_sample_2d`.
///
/// Reads `tensor` as a contiguous `[N, C, H_in, W_in]` buffer and `grid` as a contiguous
/// `[N, H_out, W_out, 2]` buffer, both of element type `T`. `grid[..., 0]` is the x (width)
/// coordinate and `grid[..., 1]` the y (height) coordinate, normalized to `[-1, 1]`.
/// Returns a contiguous `[N, C, H_out, W_out]` tensor of dtype `T`.
///
/// All arithmetic is done in `f64`. Results are cast back to `T` once per output element.
///
/// # Panics
/// - If either input is not 4-D, if the grid's last dim is not 2, or if the batch sizes differ.
/// - If `NumCast` cannot convert an `f64` result back to `T`.
fn grid_sample_2d_impl<T>(
    tensor: FlexTensor,
    grid: FlexTensor,
    options: GridSampleOptions,
) -> FlexTensor
where
    T: Float + Element + bytemuck::Pod,
{
    let t_shape = tensor.layout().shape();
    let g_shape = grid.layout().shape();
    assert_eq!(t_shape.num_dims(), 4, "grid_sample_2d: input must be 4D");
    assert_eq!(g_shape.num_dims(), 4, "grid_sample_2d: grid must be 4D");
    assert_eq!(g_shape[3], 2, "grid_sample_2d: grid last dim must be 2");
    assert_eq!(
        t_shape[0], g_shape[0],
        "grid_sample_2d: batch size mismatch"
    );

    let batch_size = t_shape[0];
    let channels = t_shape[1];
    let h_in = t_shape[2];
    let w_in = t_shape[3];
    let h_out = g_shape[1];
    let w_out = g_shape[2];

    let tensor_data: &[T] = tensor.storage();
    let grid_data: &[T] = grid.storage();

    let out_shape = Shape::from(vec![batch_size, channels, h_out, w_out]);
    let out_len = batch_size * channels * h_out * w_out;
    let mut output: Vec<T> = vec![T::zero(); out_len];

    let align = options.align_corners;
    let pad_mode = options.padding_mode;
    // `grid_sample_2d` only forwards Bilinear or Nearest, so "not nearest" means bilinear.
    let nearest = matches!(options.mode, InterpolateMode::Nearest);

    // Row-major strides of the contiguous input, grid and output buffers.
    let t_stride_n = channels * h_in * w_in;
    let t_stride_c = h_in * w_in;
    let t_stride_h = w_in;
    let g_stride_n = h_out * w_out * 2;
    let g_stride_h = w_out * 2;
    let o_stride_n = channels * h_out * w_out;
    let o_stride_c = h_out * w_out;
    let o_stride_h = w_out;

    let to_f64 = |x: T| -> f64 { ToElement::to_f64(&x) };
    let from_f64 = |x: f64| -> T {
        <T as NumCast>::from(x).unwrap_or_else(|| {
            panic!(
                "grid_sample_2d: NumCast::from({x:?}) to {:?} returned None",
                T::dtype()
            )
        })
    };

    // Fetch one input texel of the channel plane that starts at `t_base`.
    // Zeros padding yields 0 outside the image. Border/Reflection clamp the index, which
    // absorbs neighbour taps that land just past the edge, and non-finite coordinates
    // that `apply_padding` leaves untouched.
    let read = |t_base: usize, xi: i64, yi: i64| -> f64 {
        match pad_mode {
            GridSamplePaddingMode::Zeros => {
                if xi >= 0 && xi < w_in as i64 && yi >= 0 && yi < h_in as i64 {
                    to_f64(tensor_data[t_base + yi as usize * t_stride_h + xi as usize])
                } else {
                    0.0
                }
            }
            GridSamplePaddingMode::Border | GridSamplePaddingMode::Reflection => {
                let xi = xi.clamp(0, (w_in - 1) as i64) as usize;
                let yi = yi.clamp(0, (h_in - 1) as i64) as usize;
                to_f64(tensor_data[t_base + yi * t_stride_h + xi])
            }
        }
    };

    for b in 0..batch_size {
        let t_batch = b * t_stride_n;
        for y in 0..h_out {
            for x in 0..w_out {
                let g_idx = b * g_stride_n + y * g_stride_h + x * 2;
                let sample_x = to_f64(grid_data[g_idx]);
                let sample_y = to_f64(grid_data[g_idx + 1]);

                // Unnormalize [-1, 1] to pixel space: corner-pixel centers when
                // `align_corners`, otherwise the outer pixel edges.
                let (px, py) = if align {
                    let px = (sample_x + 1.0) * ((w_in - 1) as f64) / 2.0;
                    let py = (sample_y + 1.0) * ((h_in - 1) as f64) / 2.0;
                    (px, py)
                } else {
                    let px = (sample_x + 1.0) * (w_in as f64) / 2.0 - 0.5;
                    let py = (sample_y + 1.0) * (h_in as f64) / 2.0 - 0.5;
                    (px, py)
                };
                let (px, py) = apply_padding(px, py, w_in, h_in, pad_mode, align);

                // Output offset of channel 0 for this pixel. The sampling taps and weights
                // do not depend on the channel, so they are resolved once per pixel.
                let o_pixel = b * o_stride_n + y * o_stride_h + x;

                if nearest {
                    // Ties go to the even neighbour to match the reference grid_sample.
                    let xi = round_half_to_even(px) as i64;
                    let yi = round_half_to_even(py) as i64;
                    for c in 0..channels {
                        let t_base = t_batch + c * t_stride_c;
                        output[o_pixel + c * o_stride_c] = from_f64(read(t_base, xi, yi));
                    }
                } else {
                    let x0_f = px.floor();
                    let y0_f = py.floor();
                    let x0 = x0_f as i64;
                    let y0 = y0_f as i64;
                    // `as` saturates huge or infinite coordinates to i64::MAX, and a plain `+ 1`
                    // would then overflow (a panic in debug builds). Saturating keeps the
                    // tap out of bounds.
                    let x1 = x0.saturating_add(1);
                    let y1 = y0.saturating_add(1);
                    let x_frac = px - x0_f;
                    let y_frac = py - y0_f;
                    let w00 = (1.0 - x_frac) * (1.0 - y_frac);
                    let w01 = (1.0 - x_frac) * y_frac;
                    let w10 = x_frac * (1.0 - y_frac);
                    let w11 = x_frac * y_frac;
                    for c in 0..channels {
                        let t_base = t_batch + c * t_stride_c;
                        let val = read(t_base, x0, y0) * w00
                            + read(t_base, x0, y1) * w01
                            + read(t_base, x1, y0) * w10
                            + read(t_base, x1, y1) * w11;
                        output[o_pixel + c * o_stride_c] = from_f64(val);
                    }
                }
            }
        }
    }

    let bytes = Bytes::from_elems(output);
    FlexTensor::new(bytes, Layout::contiguous(out_shape), T::dtype())
}

/// Rounds to the nearest integer, resolving exact `.5` ties to the even neighbour
/// (round-half-to-even, like C `nearbyint` under the default rounding mode).
///
/// `round()` sends ties away from zero. That selects a different texel than the reference
/// nearest-mode grid_sample whenever a sample lands exactly halfway between two pixels.
fn round_half_to_even(v: f64) -> f64 {
    let r = v.round();
    if (r - v).abs() == 0.5 {
        // Exact tie: halve, round, double. This always lands on the even integer.
        2.0 * (v / 2.0).round()
    } else {
        r
    }
}
/// Maps unnormalized pixel-space sample coordinates according to the padding mode.
///
/// - `Zeros`: coordinates pass through unchanged. Out-of-range taps are zeroed when read.
/// - `Border`: coordinates are clamped to `[0, w - 1]` x `[0, h - 1]`. If either coordinate
///   is non-finite, both are replaced by the image center `((w - 1) / 2, (h - 1) / 2)`.
/// - `Reflection`: each finite coordinate is mirrored into range on its own axis via
///   `reflect_coordinate`. A non-finite coordinate is passed through and left for the
///   index clamp at read time, without suppressing the reflection of the other axis.
///
/// `w`/`h` are assumed non-zero for Border/Reflection; `w - 1` would underflow otherwise.
fn apply_padding(
    px: f64,
    py: f64,
    w: usize,
    h: usize,
    mode: GridSamplePaddingMode,
    align_corners: bool,
) -> (f64, f64) {
    match mode {
        GridSamplePaddingMode::Zeros => (px, py),
        GridSamplePaddingMode::Border => {
            let x_max = (w - 1) as f64;
            let y_max = (h - 1) as f64;
            if !px.is_finite() || !py.is_finite() {
                // f64::clamp would propagate NaN, so fall back to the image center.
                return (x_max / 2.0, y_max / 2.0);
            }
            (px.clamp(0.0, x_max), py.clamp(0.0, y_max))
        }
        GridSamplePaddingMode::Reflection => {
            // Axes are handled independently: a bad x must not leave y un-reflected.
            let reflect = |coord: f64, size: usize| {
                if coord.is_finite() {
                    reflect_coordinate(coord, size, align_corners)
                } else {
                    coord
                }
            };
            (reflect(px, w), reflect(py, h))
        }
    }
}
/// Folds `coord` into the valid range of an axis of length `size` by mirror reflection.
///
/// The range is `[0, size - 1]` (pixel centers) when `align_corners` is set, otherwise
/// `[-0.5, size - 0.5]` (pixel edges). Points outside bounce back and forth between the two
/// bounds, giving a triangle wave with period `2 * span`. A degenerate axis (`span <= 0`)
/// collapses every input to the lower bound.
fn reflect_coordinate(coord: f64, size: usize, align_corners: bool) -> f64 {
    let extent_f = size as f64;
    let lo = if align_corners { 0.0 } else { -0.5 };
    let hi = if align_corners {
        extent_f - 1.0
    } else {
        extent_f - 0.5
    };

    let span = hi - lo;
    if !(span > 0.0) {
        return lo;
    }

    // Distance from the lower bound, reduced into one reflection cycle [0, 2*span).
    let cycle = 2.0 * span;
    let dist = (coord - lo).abs();
    let folded = dist - (dist / cycle).floor() * cycle;

    // Triangle wave: rising on [0, span], falling on [span, 2*span].
    span - (folded - span).abs() + lo
}