// burn_cubecl/kernel/grid_sample/base.rs
use cubecl::prelude::*;

use crate::{CubeRuntime, tensor::CubeTensor};
use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};

use super::bilinear::grid_sample_bilinear_launch;

8pub fn grid_sample<R: CubeRuntime>(
10 input: CubeTensor<R>,
11 grid: CubeTensor<R>,
12 options: GridSampleOptions,
13) -> CubeTensor<R> {
14 match options.mode {
15 InterpolateMode::Bilinear => grid_sample_bilinear_launch(input, grid, options),
16 _ => panic!(
17 "Unsupported grid_sample interpolation mode: {:?}",
18 options.mode
19 ),
20 }
21}
22
/// Kernel-side selector for how reads outside the input bounds are resolved.
///
/// Passed as a compile-time argument to `fetch_value`, which picks the
/// matching fetch helper for each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PaddingMode {
    /// Out-of-bounds positions read as zero (see `fetch_with_zeros`).
    Zeros,
    /// Out-of-bounds positions are clamped to the nearest valid index
    /// (see `fetch_with_border`).
    Border,
    /// Out-of-bounds positions are mirrored back into the valid range
    /// (see `fetch_with_reflection`).
    Reflection,
}
33
34impl From<GridSamplePaddingMode> for PaddingMode {
35 fn from(mode: GridSamplePaddingMode) -> Self {
36 match mode {
37 GridSamplePaddingMode::Zeros => PaddingMode::Zeros,
38 GridSamplePaddingMode::Border => PaddingMode::Border,
39 GridSamplePaddingMode::Reflection => PaddingMode::Reflection,
40 }
41 }
42}
43
44#[cube]
46pub(crate) fn fetch_value<F: Float>(
47 input: &Tensor<F>,
48 base: usize,
49 stride_h: usize,
50 stride_w: usize,
51 y: i32,
52 x: i32,
53 h: i32,
54 w: i32,
55 #[comptime] padding_mode: PaddingMode,
56) -> F {
57 match padding_mode {
58 PaddingMode::Zeros => fetch_with_zeros(input, base, stride_h, stride_w, y, x, h, w),
59 PaddingMode::Border => fetch_with_border(input, base, stride_h, stride_w, y, x, h, w),
60 PaddingMode::Reflection => {
61 fetch_with_reflection(input, base, stride_h, stride_w, y, x, h, w)
62 }
63 }
64}
65
66#[cube]
68pub(crate) fn fetch_with_zeros<F: Float>(
69 input: &Tensor<F>,
70 base: usize,
71 stride_h: usize,
72 stride_w: usize,
73 y: i32,
74 x: i32,
75 h: i32,
76 w: i32,
77) -> F {
78 let in_bounds = x >= 0 && x < w && y >= 0 && y < h;
79 let x_clamped = clamp(x, 0, w - 1) as usize;
80 let y_clamped = clamp(y, 0, h - 1) as usize;
81 let idx = base + y_clamped * stride_h + x_clamped * stride_w;
82 select(in_bounds, input[idx], F::new(0.0))
83}
84
85#[cube]
87pub(crate) fn fetch_with_border<F: Float>(
88 input: &Tensor<F>,
89 base: usize,
90 stride_h: usize,
91 stride_w: usize,
92 y: i32,
93 x: i32,
94 h: i32,
95 w: i32,
96) -> F {
97 let x_clamped = clamp(x, 0, w - 1) as usize;
98 let y_clamped = clamp(y, 0, h - 1) as usize;
99 let idx = base + y_clamped * stride_h + x_clamped * stride_w;
100 input[idx]
101}
102
103#[cube]
106pub(crate) fn fetch_with_reflection<F: Float>(
107 input: &Tensor<F>,
108 base: usize,
109 stride_h: usize,
110 stride_w: usize,
111 y: i32,
112 x: i32,
113 h: i32,
114 w: i32,
115) -> F {
116 let x_reflected = reflect_coord_bounded(x, w);
117 let y_reflected = reflect_coord_bounded(y, h);
118 let idx = base + y_reflected * stride_h + x_reflected * stride_w;
119 input[idx]
120}
121
122#[cube]
125fn reflect_coord_bounded(idx: i32, size: i32) -> usize {
126 let max_idx = size - 1;
127 let neg_reflected = -idx - 1;
128 let pos_reflected = 2 * max_idx + 1 - idx;
129 let result = select(
130 idx < 0,
131 neg_reflected,
132 select(idx > max_idx, pos_reflected, idx),
133 );
134 clamp(result, 0, max_idx) as usize
135}
136
137#[cube]
139pub(crate) fn reflect_coord<F: Float>(coord: F, size: u32, #[comptime] align_corners: bool) -> F {
140 let size_f = F::cast_from(size);
141 if align_corners {
142 reflect_float_impl::<F>(coord, F::new(0.0), size_f - F::new(1.0))
143 } else {
144 reflect_float_impl::<F>(coord, F::new(-0.5), size_f - F::new(0.5))
145 }
146}
147
148#[cube]
150fn reflect_float_impl<F: Float>(coord: F, min_val: F, max_val: F) -> F {
151 let span = max_val - min_val;
152
153 let is_valid = span > F::new(0.0);
154 let safe_span = select(is_valid, span, F::new(1.0));
155
156 let period = safe_span * F::new(2.0);
158 let x = (coord - min_val).abs();
159 let x_mod = x - (x / period).floor() * period;
160 let reflected = safe_span - (x_mod - safe_span).abs() + min_val;
161
162 select(is_valid, reflected, min_val)
163}