// burn_cubecl/kernel/grid_sample/base.rs

1use cubecl::prelude::*;
2
3use crate::{CubeRuntime, tensor::CubeTensor};
4use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
5
6use super::bilinear::grid_sample_bilinear_launch;
7
8/// Grid sample operation supporting bilinear interpolation
9pub fn grid_sample<R: CubeRuntime>(
10    input: CubeTensor<R>,
11    grid: CubeTensor<R>,
12    options: GridSampleOptions,
13) -> CubeTensor<R> {
14    match options.mode {
15        InterpolateMode::Bilinear => grid_sample_bilinear_launch(input, grid, options),
16        _ => panic!(
17            "Unsupported grid_sample interpolation mode: {:?}",
18            options.mode
19        ),
20    }
21}
22
/// How the grid-sample kernels treat sample positions that fall outside the input.
///
/// Used as a compile-time (`#[comptime]`) kernel parameter so that each mode
/// produces its own specialized kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PaddingMode {
    /// Out-of-bounds samples read as zero.
    Zeros,
    /// Out-of-bounds coordinates are clamped, so the nearest edge value is repeated.
    Border,
    /// Out-of-bounds coordinates are mirrored back across the boundary.
    Reflection,
}
33
34impl From<GridSamplePaddingMode> for PaddingMode {
35    fn from(mode: GridSamplePaddingMode) -> Self {
36        match mode {
37            GridSamplePaddingMode::Zeros => PaddingMode::Zeros,
38            GridSamplePaddingMode::Border => PaddingMode::Border,
39            GridSamplePaddingMode::Reflection => PaddingMode::Reflection,
40        }
41    }
42}
43
44/// Fetch value based on padding mode (dispatch to appropriate handler)
45#[cube]
46pub(crate) fn fetch_value<F: Float>(
47    input: &Tensor<F>,
48    base: usize,
49    stride_h: usize,
50    stride_w: usize,
51    y: i32,
52    x: i32,
53    h: i32,
54    w: i32,
55    #[comptime] padding_mode: PaddingMode,
56) -> F {
57    match padding_mode {
58        PaddingMode::Zeros => fetch_with_zeros(input, base, stride_h, stride_w, y, x, h, w),
59        PaddingMode::Border => fetch_with_border(input, base, stride_h, stride_w, y, x, h, w),
60        PaddingMode::Reflection => {
61            fetch_with_reflection(input, base, stride_h, stride_w, y, x, h, w)
62        }
63    }
64}
65
66/// Fetch value with zeros padding (return 0 for out-of-bounds).
67#[cube]
68pub(crate) fn fetch_with_zeros<F: Float>(
69    input: &Tensor<F>,
70    base: usize,
71    stride_h: usize,
72    stride_w: usize,
73    y: i32,
74    x: i32,
75    h: i32,
76    w: i32,
77) -> F {
78    let in_bounds = x >= 0 && x < w && y >= 0 && y < h;
79    let x_clamped = clamp(x, 0, w - 1) as usize;
80    let y_clamped = clamp(y, 0, h - 1) as usize;
81    let idx = base + y_clamped * stride_h + x_clamped * stride_w;
82    select(in_bounds, input[idx], F::new(0.0))
83}
84
85/// Fetch value with border padding (clamp to edge).
86#[cube]
87pub(crate) fn fetch_with_border<F: Float>(
88    input: &Tensor<F>,
89    base: usize,
90    stride_h: usize,
91    stride_w: usize,
92    y: i32,
93    x: i32,
94    h: i32,
95    w: i32,
96) -> F {
97    let x_clamped = clamp(x, 0, w - 1) as usize;
98    let y_clamped = clamp(y, 0, h - 1) as usize;
99    let idx = base + y_clamped * stride_h + x_clamped * stride_w;
100    input[idx]
101}
102
103/// Fetch value with reflection padding.
104/// Assumes float reflection was applied to center, so indices are at most 2 steps out of bounds.
105#[cube]
106pub(crate) fn fetch_with_reflection<F: Float>(
107    input: &Tensor<F>,
108    base: usize,
109    stride_h: usize,
110    stride_w: usize,
111    y: i32,
112    x: i32,
113    h: i32,
114    w: i32,
115) -> F {
116    let x_reflected = reflect_coord_bounded(x, w);
117    let y_reflected = reflect_coord_bounded(y, h);
118    let idx = base + y_reflected * stride_h + x_reflected * stride_w;
119    input[idx]
120}
121
122/// Reflect an integer index that may be out of bounds.
123/// After float reflection, indices can be up to 2 steps out for bicubic (1 step for bilinear).
124#[cube]
125fn reflect_coord_bounded(idx: i32, size: i32) -> usize {
126    let max_idx = size - 1;
127    let neg_reflected = -idx - 1;
128    let pos_reflected = 2 * max_idx + 1 - idx;
129    let result = select(
130        idx < 0,
131        neg_reflected,
132        select(idx > max_idx, pos_reflected, idx),
133    );
134    clamp(result, 0, max_idx) as usize
135}
136
137/// Reflect a float coordinate into the valid sampling range.
138#[cube]
139pub(crate) fn reflect_coord<F: Float>(coord: F, size: u32, #[comptime] align_corners: bool) -> F {
140    let size_f = F::cast_from(size);
141    if align_corners {
142        reflect_float_impl::<F>(coord, F::new(0.0), size_f - F::new(1.0))
143    } else {
144        reflect_float_impl::<F>(coord, F::new(-0.5), size_f - F::new(0.5))
145    }
146}
147
148/// Reflect a float coordinate into [min_val, max_val] using a triangle wave pattern.
149#[cube]
150fn reflect_float_impl<F: Float>(coord: F, min_val: F, max_val: F) -> F {
151    let span = max_val - min_val;
152
153    let is_valid = span > F::new(0.0);
154    let safe_span = select(is_valid, span, F::new(1.0));
155
156    // Triangle wave formula: span - |((x mod 2*span) - span)| + min_val
157    let period = safe_span * F::new(2.0);
158    let x = (coord - min_val).abs();
159    let x_mod = x - (x / period).floor() * period;
160    let reflected = safe_span - (x_mod - safe_span).abs() + min_val;
161
162    select(is_valid, reflected, min_val)
163}