use cubecl::prelude::*;
use crate::{CubeRuntime, tensor::CubeTensor};
use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
use super::bilinear::grid_sample_bilinear_launch;
/// Samples `input` at the locations described by `grid`.
///
/// The interpolation mode is read from `options.mode`. Only
/// [`InterpolateMode::Bilinear`] is handled here; it is forwarded to
/// [`grid_sample_bilinear_launch`] together with the unchanged `options`.
///
/// # Panics
///
/// Panics if `options.mode` is anything other than
/// [`InterpolateMode::Bilinear`].
pub fn grid_sample<R: CubeRuntime>(
    input: CubeTensor<R>,
    grid: CubeTensor<R>,
    options: GridSampleOptions,
) -> CubeTensor<R> {
    // Reject every mode we have no kernel for before launching anything.
    if !matches!(options.mode, InterpolateMode::Bilinear) {
        panic!(
            "Unsupported grid_sample interpolation mode: {:?}",
            options.mode
        );
    }

    grid_sample_bilinear_launch(input, grid, options)
}
/// Strategy used when a sample position falls outside the input extent.
///
/// This is the kernel-side counterpart of `GridSamplePaddingMode`; it is
/// passed to `fetch_value` as a compile-time parameter to pick the fetch
/// routine.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PaddingMode {
    /// Out-of-range positions yield `0.0` (see `fetch_with_zeros`).
    Zeros,
    /// Out-of-range positions are clamped to the nearest valid index
    /// (see `fetch_with_border`).
    Border,
    /// Out-of-range positions are mirrored back into range
    /// (see `fetch_with_reflection`).
    Reflection,
}
/// One-to-one conversion from the backend-level padding option to the
/// kernel-level [`PaddingMode`].
impl From<GridSamplePaddingMode> for PaddingMode {
    /// Maps each `GridSamplePaddingMode` variant to the variant of the same
    /// name.
    fn from(mode: GridSamplePaddingMode) -> Self {
        match mode {
            GridSamplePaddingMode::Zeros => Self::Zeros,
            GridSamplePaddingMode::Border => Self::Border,
            GridSamplePaddingMode::Reflection => Self::Reflection,
        }
    }
}
/// Reads one element of `input` at integer position (`y`, `x`), applying the
/// out-of-range policy selected by `padding_mode`.
///
/// # Arguments
///
/// * `input` - tensor to read from.
/// * `base` - linear offset added to every computed index.
/// * `stride_h`, `stride_w` - multipliers applied to the row and column index.
/// * `y`, `x` - requested row / column; may lie outside `[0, h)` / `[0, w)`.
/// * `h`, `w` - extent used for bounds handling.
/// * `padding_mode` - compile-time selector of the fetch routine.
///
/// # Returns
///
/// Whatever the selected routine (`fetch_with_zeros`, `fetch_with_border` or
/// `fetch_with_reflection`) returns for these arguments.
#[cube]
pub(crate) fn fetch_value<F: Float>(
    input: &Tensor<F>,
    base: usize,
    stride_h: usize,
    stride_w: usize,
    y: i32,
    x: i32,
    h: i32,
    w: i32,
    #[comptime] padding_mode: PaddingMode,
) -> F {
    // `padding_mode` is comptime, so exactly one arm ends up in the kernel.
    match padding_mode {
        PaddingMode::Reflection => {
            fetch_with_reflection(input, base, stride_h, stride_w, y, x, h, w)
        }
        PaddingMode::Border => {
            fetch_with_border(input, base, stride_h, stride_w, y, x, h, w)
        }
        PaddingMode::Zeros => {
            fetch_with_zeros(input, base, stride_h, stride_w, y, x, h, w)
        }
    }
}
/// Reads `input` at (`y`, `x`), returning `0.0` when the position is outside
/// `[0, h)` x `[0, w)`.
///
/// The element index is `base + y * stride_h + x * stride_w`, computed from
/// coordinates clamped into `[0, h - 1]` / `[0, w - 1]`. The read is therefore
/// issued at the clamped position even for out-of-range requests; its value is
/// simply discarded by the final `select`.
#[cube]
pub(crate) fn fetch_with_zeros<F: Float>(
    input: &Tensor<F>,
    base: usize,
    stride_h: usize,
    stride_w: usize,
    y: i32,
    x: i32,
    h: i32,
    w: i32,
) -> F {
    // Clamp before casting so a negative coordinate never reaches `as usize`.
    let row = clamp(y, 0, h - 1) as usize;
    let col = clamp(x, 0, w - 1) as usize;
    let offset = base + row * stride_h + col * stride_w;

    // Decide on the *unclamped* coordinates whether the sample is real.
    let inside = y >= 0 && y < h && x >= 0 && x < w;
    select(inside, input[offset], F::new(0.0))
}
/// Reads `input` at (`y`, `x`) after clamping the position to the nearest
/// valid index.
///
/// `y` is clamped into `[0, h - 1]` and `x` into `[0, w - 1]`; the element at
/// `base + y * stride_h + x * stride_w` (using the clamped values) is returned.
#[cube]
pub(crate) fn fetch_with_border<F: Float>(
    input: &Tensor<F>,
    base: usize,
    stride_h: usize,
    stride_w: usize,
    y: i32,
    x: i32,
    h: i32,
    w: i32,
) -> F {
    // Clamp before casting so a negative coordinate never reaches `as usize`.
    let row = clamp(y, 0, h - 1) as usize;
    let col = clamp(x, 0, w - 1) as usize;
    let offset = base + row * stride_h + col * stride_w;
    input[offset]
}
/// Reads `input` at (`y`, `x`) after mirroring out-of-range positions back
/// into range.
///
/// Each coordinate is passed through `reflect_coord_bounded` with its extent
/// (`h` for `y`, `w` for `x`); the element at
/// `base + y * stride_h + x * stride_w` (using the reflected values) is
/// returned.
#[cube]
pub(crate) fn fetch_with_reflection<F: Float>(
    input: &Tensor<F>,
    base: usize,
    stride_h: usize,
    stride_w: usize,
    y: i32,
    x: i32,
    h: i32,
    w: i32,
) -> F {
    let row = reflect_coord_bounded(y, h);
    let col = reflect_coord_bounded(x, w);
    let offset = base + row * stride_h + col * stride_w;
    input[offset]
}
/// Folds an integer index into `[0, size - 1]` with a single mirror step.
///
/// * `idx < 0` maps to `-idx - 1`.
/// * `idx > size - 1` maps to `2 * (size - 1) + 1 - idx`.
/// * Anything else is kept as is.
///
/// The mirrored value is then clamped into `[0, size - 1]`, so indices more
/// than one period out of range saturate at the edge instead of being
/// reflected again.
// NOTE(review): presumably callers only pass indices that are already close
// to the valid range — confirm against the call sites.
#[cube]
fn reflect_coord_bounded(idx: i32, size: i32) -> usize {
    let last = size - 1;

    // Candidate results for the "past the end" and "before the start" cases.
    let mirrored_high = 2 * last + 1 - idx;
    let mirrored_low = -idx - 1;

    let high_or_same = select(idx > last, mirrored_high, idx);
    let folded = select(idx < 0, mirrored_low, high_or_same);

    clamp(folded, 0, last) as usize
}
/// Reflects a floating-point coordinate back into the sampling range of an
/// axis with `size` elements.
///
/// The range depends on the compile-time flag `align_corners`:
///
/// * `true`  - reflect between `0` and `size - 1`.
/// * `false` - reflect between `-0.5` and `size - 0.5`.
///
/// The reflection itself is performed by `reflect_float_impl`.
#[cube]
pub(crate) fn reflect_coord<F: Float>(coord: F, size: u32, #[comptime] align_corners: bool) -> F {
    let extent = F::cast_from(size);
    if align_corners {
        let low = F::new(0.0);
        let high = extent - F::new(1.0);
        reflect_float_impl::<F>(coord, low, high)
    } else {
        let low = F::new(-0.5);
        let high = extent - F::new(0.5);
        reflect_float_impl::<F>(coord, low, high)
    }
}
/// Reflects `coord` into the interval `[min_val, max_val]`.
///
/// The distance of `coord` from `min_val` is wrapped into one period of
/// length `2 * (max_val - min_val)` and then folded around the midpoint of
/// that period, which yields a triangle wave over the interval.
///
/// When `max_val - min_val` is not strictly positive the interval is treated
/// as degenerate and `min_val` is returned.
#[cube]
fn reflect_float_impl<F: Float>(coord: F, min_val: F, max_val: F) -> F {
    let width = max_val - min_val;
    let usable = width > F::new(0.0);

    // Substitute a harmless width so the divisions below stay well-defined;
    // the final `select` discards the result in that case anyway.
    let safe_width = select(usable, width, F::new(1.0));
    let period = safe_width * F::new(2.0);

    // Distance from the lower bound, wrapped into [0, period).
    let shifted = coord - min_val;
    let distance = shifted.abs();
    let cycles = (distance / period).floor();
    let wrapped = distance - cycles * period;

    // Fold the second half of the period back onto the first.
    let overshoot = (wrapped - safe_width).abs();
    let folded = safe_width - overshoot + min_val;

    select(usable, folded, min_val)
}