1use crate::{
2 Backend, TensorMetadata, get_device_settings,
3 ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},
4 tensor::FloatTensor,
5};
6use alloc::vec;
7use burn_std::{Shape, Slice};
8
9pub fn float_grid_sample_2d_ref<B: Backend>(
22 tensor: FloatTensor<B>,
23 grid: FloatTensor<B>,
24 options: GridSampleOptions,
25) -> FloatTensor<B> {
26 match options.mode {
27 InterpolateMode::Bilinear => float_grid_sample_2d_bilinear::<B>(
28 tensor,
29 grid,
30 options.padding_mode,
31 options.align_corners,
32 ),
33 _ => todo!(
34 "Default implementation for grid_sample_2d with {:?} unimplemented",
35 options.mode
36 ),
37 }
38}
39
40fn float_grid_sample_2d_bilinear<B: Backend>(
42 tensor: FloatTensor<B>,
43 grid: FloatTensor<B>,
44 padding_mode: GridSamplePaddingMode,
45 align_corners: bool,
46) -> FloatTensor<B> {
47 let n = tensor.shape()[0];
48 let c = tensor.shape()[1];
49 let h_in = tensor.shape()[2];
50 let w_in = tensor.shape()[3];
51 let h_out = grid.shape()[1];
52 let w_out = grid.shape()[2];
53 let spatial_in = h_in * w_in;
54 let spatial_out = h_out * w_out;
55 let device = B::float_device(&tensor);
56
57 let grid_x_slice = vec![
60 Slice::new(0, Some(n as isize), 1),
61 Slice::new(0, Some(h_out as isize), 1),
62 Slice::new(0, Some(w_out as isize), 1),
63 Slice::new(0, Some(1), 1),
64 ];
65 let grid_y_slice = vec![
66 Slice::new(0, Some(n as isize), 1),
67 Slice::new(0, Some(h_out as isize), 1),
68 Slice::new(0, Some(w_out as isize), 1),
69 Slice::new(1, Some(2), 1),
70 ];
71
72 let grid_x = B::float_slice(grid.clone(), &grid_x_slice);
73 let grid_x = B::float_reshape(grid_x, Shape::new([n, 1, h_out, w_out]));
74 let grid_y = B::float_slice(grid.clone(), &grid_y_slice);
75 let grid_y = B::float_reshape(grid_y, Shape::new([n, 1, h_out, w_out]));
76
77 let w_in_f = w_in as f64;
79 let h_in_f = h_in as f64;
80
81 let (grid_x, grid_y) = if align_corners {
82 let grid_x = B::float_add_scalar(grid_x, 1f32.into());
85 let grid_x = B::float_mul_scalar(grid_x, ((w_in_f - 1.0) / 2.0).into());
86
87 let grid_y = B::float_add_scalar(grid_y, 1f32.into());
88 let grid_y = B::float_mul_scalar(grid_y, ((h_in_f - 1.0) / 2.0).into());
89
90 (grid_x, grid_y)
91 } else {
92 let grid_x = B::float_add_scalar(grid_x, 1f32.into());
95 let grid_x = B::float_mul_scalar(grid_x, (w_in_f / 2.0).into());
96 let grid_x = B::float_sub_scalar(grid_x, 0.5f32.into());
97
98 let grid_y = B::float_add_scalar(grid_y, 1f32.into());
99 let grid_y = B::float_mul_scalar(grid_y, (h_in_f / 2.0).into());
100 let grid_y = B::float_sub_scalar(grid_y, 0.5f32.into());
101
102 (grid_x, grid_y)
103 };
104
105 let (grid_x, grid_y) = match padding_mode {
107 GridSamplePaddingMode::Border => {
108 let grid_x = B::float_clamp(grid_x, 0f32.into(), ((w_in - 1) as f32).into());
110 let grid_y = B::float_clamp(grid_y, 0f32.into(), ((h_in - 1) as f32).into());
111 (grid_x, grid_y)
112 }
113 GridSamplePaddingMode::Reflection => {
114 let grid_x = reflect_coordinates::<B>(grid_x, w_in_f, align_corners);
116 let grid_y = reflect_coordinates::<B>(grid_y, h_in_f, align_corners);
117 (grid_x, grid_y)
118 }
119 GridSamplePaddingMode::Zeros => {
120 (grid_x, grid_y)
122 }
123 };
124
125 let grid_x_floored = B::float_floor(grid_x.clone());
127 let grid_y_floored = B::float_floor(grid_y.clone());
128
129 let x_frac = B::float_sub(grid_x.clone(), grid_x_floored.clone());
131 let y_frac = B::float_sub(grid_y.clone(), grid_y_floored.clone());
132
133 let settings = get_device_settings::<B>(&device);
135 let x0 = B::float_into_int(grid_x_floored.clone(), settings.int_dtype);
136 let y0 = B::float_into_int(grid_y_floored.clone(), settings.int_dtype);
137 let x1 = B::float_into_int(
138 B::float_add_scalar(grid_x_floored, 1f32.into()),
139 settings.int_dtype,
140 );
141 let y1 = B::float_into_int(
142 B::float_add_scalar(grid_y_floored, 1f32.into()),
143 settings.int_dtype,
144 );
145
146 let (mask_00, mask_01, mask_10, mask_11) = if padding_mode == GridSamplePaddingMode::Zeros {
148 let x0_valid = B::int_greater_equal_elem(x0.clone(), 0.into(), settings.bool_dtype);
149 let x0_valid = B::bool_and(
150 x0_valid,
151 B::int_lower_elem(x0.clone(), (w_in as i32).into(), settings.bool_dtype),
152 );
153 let x1_valid = B::int_greater_equal_elem(x1.clone(), 0.into(), settings.bool_dtype);
154 let x1_valid = B::bool_and(
155 x1_valid,
156 B::int_lower_elem(x1.clone(), (w_in as i32).into(), settings.bool_dtype),
157 );
158 let y0_valid = B::int_greater_equal_elem(y0.clone(), 0.into(), settings.bool_dtype);
159 let y0_valid = B::bool_and(
160 y0_valid,
161 B::int_lower_elem(y0.clone(), (h_in as i32).into(), settings.bool_dtype),
162 );
163 let y1_valid = B::int_greater_equal_elem(y1.clone(), 0.into(), settings.bool_dtype);
164 let y1_valid = B::bool_and(
165 y1_valid,
166 B::int_lower_elem(y1.clone(), (h_in as i32).into(), settings.bool_dtype),
167 );
168
169 (
170 Some(B::bool_and(x0_valid.clone(), y0_valid.clone())),
171 Some(B::bool_and(x0_valid.clone(), y1_valid.clone())),
172 Some(B::bool_and(x1_valid.clone(), y0_valid)),
173 Some(B::bool_and(x1_valid, y1_valid)),
174 )
175 } else {
176 (None, None, None, None)
177 };
178
179 let x0_clamped = B::int_clamp(x0, 0.into(), ((w_in - 1) as i32).into());
181 let x1_clamped = B::int_clamp(x1, 0.into(), ((w_in - 1) as i32).into());
182 let y0_clamped = B::int_clamp(y0, 0.into(), ((h_in - 1) as i32).into());
183 let y1_clamped = B::int_clamp(y1, 0.into(), ((h_in - 1) as i32).into());
184
185 let w_in_scalar: i32 = w_in as i32;
187 let idx_00 = B::int_add(
188 B::int_mul_scalar(y0_clamped.clone(), w_in_scalar.into()),
189 x0_clamped.clone(),
190 );
191 let idx_01 = B::int_add(
192 B::int_mul_scalar(y1_clamped.clone(), w_in_scalar.into()),
193 x0_clamped,
194 );
195 let idx_10 = B::int_add(
196 B::int_mul_scalar(y0_clamped, w_in_scalar.into()),
197 x1_clamped.clone(),
198 );
199 let idx_11 = B::int_add(
200 B::int_mul_scalar(y1_clamped, w_in_scalar.into()),
201 x1_clamped,
202 );
203
204 let idx_00 = B::int_reshape(idx_00, Shape::new([n, 1, spatial_out]));
206 let idx_01 = B::int_reshape(idx_01, Shape::new([n, 1, spatial_out]));
207 let idx_10 = B::int_reshape(idx_10, Shape::new([n, 1, spatial_out]));
208 let idx_11 = B::int_reshape(idx_11, Shape::new([n, 1, spatial_out]));
209
210 let idx_00 = B::int_expand(idx_00, Shape::new([n, c, spatial_out]));
212 let idx_01 = B::int_expand(idx_01, Shape::new([n, c, spatial_out]));
213 let idx_10 = B::int_expand(idx_10, Shape::new([n, c, spatial_out]));
214 let idx_11 = B::int_expand(idx_11, Shape::new([n, c, spatial_out]));
215
216 let tensor_flat = B::float_reshape(tensor, Shape::new([n, c, spatial_in]));
217
218 let sample_00 = B::float_gather(2, tensor_flat.clone(), idx_00);
219 let sample_01 = B::float_gather(2, tensor_flat.clone(), idx_01);
220 let sample_10 = B::float_gather(2, tensor_flat.clone(), idx_10);
221 let sample_11 = B::float_gather(2, tensor_flat, idx_11);
222
223 let sample_00 = B::float_reshape(sample_00, Shape::new([n, c, h_out, w_out]));
225 let sample_01 = B::float_reshape(sample_01, Shape::new([n, c, h_out, w_out]));
226 let sample_10 = B::float_reshape(sample_10, Shape::new([n, c, h_out, w_out]));
227 let sample_11 = B::float_reshape(sample_11, Shape::new([n, c, h_out, w_out]));
228
229 let (sample_00, sample_01, sample_10, sample_11) =
231 if padding_mode == GridSamplePaddingMode::Zeros {
232 let mask_00 = mask_00.unwrap();
233 let mask_01 = mask_01.unwrap();
234 let mask_10 = mask_10.unwrap();
235 let mask_11 = mask_11.unwrap();
236
237 let mask_00_inv = B::bool_not(mask_00);
238 let mask_00_inv = B::bool_reshape(mask_00_inv, Shape::new([n, 1, h_out, w_out]));
239 let mask_00_inv = B::bool_expand(mask_00_inv, Shape::new([n, c, h_out, w_out]));
240 let mask_01_inv = B::bool_not(mask_01);
241 let mask_01_inv = B::bool_reshape(mask_01_inv, Shape::new([n, 1, h_out, w_out]));
242 let mask_01_inv = B::bool_expand(mask_01_inv, Shape::new([n, c, h_out, w_out]));
243 let mask_10_inv = B::bool_not(mask_10);
244 let mask_10_inv = B::bool_reshape(mask_10_inv, Shape::new([n, 1, h_out, w_out]));
245 let mask_10_inv = B::bool_expand(mask_10_inv, Shape::new([n, c, h_out, w_out]));
246 let mask_11_inv = B::bool_not(mask_11);
247 let mask_11_inv = B::bool_reshape(mask_11_inv, Shape::new([n, 1, h_out, w_out]));
248 let mask_11_inv = B::bool_expand(mask_11_inv, Shape::new([n, c, h_out, w_out]));
249
250 (
251 B::float_mask_fill(sample_00, mask_00_inv, 0f32.into()),
252 B::float_mask_fill(sample_01, mask_01_inv, 0f32.into()),
253 B::float_mask_fill(sample_10, mask_10_inv, 0f32.into()),
254 B::float_mask_fill(sample_11, mask_11_inv, 0f32.into()),
255 )
256 } else {
257 (sample_00, sample_01, sample_10, sample_11)
258 };
259
260 let one_minus_x = B::float_neg(x_frac.clone());
262 let one_minus_x = B::float_add_scalar(one_minus_x, 1f32.into());
263
264 let one_minus_y = B::float_neg(y_frac.clone());
265 let one_minus_y = B::float_add_scalar(one_minus_y, 1f32.into());
266
267 let weight_00 = B::float_mul(one_minus_x.clone(), one_minus_y.clone());
268 let weight_01 = B::float_mul(one_minus_x.clone(), y_frac.clone());
269 let weight_10 = B::float_mul(x_frac.clone(), one_minus_y);
270 let weight_11 = B::float_mul(x_frac, y_frac);
271
272 let result = B::float_mul(sample_00, weight_00);
274 let result = B::float_add(result, B::float_mul(sample_01, weight_01));
275 let result = B::float_add(result, B::float_mul(sample_10, weight_10));
276
277 B::float_add(result, B::float_mul(sample_11, weight_11))
278}
279
280fn reflect_coordinates<B: Backend>(
285 coords: FloatTensor<B>,
286 size: f64,
287 align_corners: bool,
288) -> FloatTensor<B> {
289 let (min_val, max_val) = if align_corners {
290 (0.0f32, (size - 1.0) as f32)
291 } else {
292 (-0.5f32, (size - 0.5) as f32)
293 };
294
295 let span = max_val - min_val;
296 if span <= 0.0 {
297 let zeros = B::float_mul_scalar(coords, 0f32.into());
299 return B::float_add_scalar(zeros, min_val.into());
300 }
301
302 let period = 2.0 * span;
304
305 let x = B::float_sub_scalar(coords, min_val.into());
307 let x = B::float_abs(x);
308
309 let x_div = B::float_div_scalar(x.clone(), period.into());
311 let x_div_floor = B::float_floor(x_div);
312 let x_mod = B::float_sub(x, B::float_mul_scalar(x_div_floor, period.into()));
313
314 let diff = B::float_sub_scalar(x_mod, span.into());
316 let abs_diff = B::float_abs(diff);
317 let reflected = B::float_sub_scalar(abs_diff, span.into());
318 let reflected = B::float_neg(reflected);
319 B::float_add_scalar(reflected, min_val.into())
320}