1use crate::{
2 Backend, TensorMetadata,
3 element::ElementConversion,
4 ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},
5 tensor::FloatTensor,
6};
7use alloc::vec;
8use burn_std::{Shape, Slice};
9
10pub fn float_grid_sample_2d_ref<B: Backend>(
23 tensor: FloatTensor<B>,
24 grid: FloatTensor<B>,
25 options: GridSampleOptions,
26) -> FloatTensor<B> {
27 match options.mode {
28 InterpolateMode::Bilinear => float_grid_sample_2d_bilinear::<B>(
29 tensor,
30 grid,
31 options.padding_mode,
32 options.align_corners,
33 ),
34 _ => todo!(
35 "Default implementation for grid_sample_2d with {:?} unimplemented",
36 options.mode
37 ),
38 }
39}
40
41fn float_grid_sample_2d_bilinear<B: Backend>(
43 tensor: FloatTensor<B>,
44 grid: FloatTensor<B>,
45 padding_mode: GridSamplePaddingMode,
46 align_corners: bool,
47) -> FloatTensor<B> {
48 let n = tensor.shape().dims[0];
49 let c = tensor.shape().dims[1];
50 let h_in = tensor.shape().dims[2];
51 let w_in = tensor.shape().dims[3];
52 let h_out = grid.shape().dims[1];
53 let w_out = grid.shape().dims[2];
54
55 let grid_x_slice = vec![
58 Slice::new(0, Some(n as isize), 1),
59 Slice::new(0, Some(h_out as isize), 1),
60 Slice::new(0, Some(w_out as isize), 1),
61 Slice::new(0, Some(1), 1),
62 ];
63 let grid_y_slice = vec![
64 Slice::new(0, Some(n as isize), 1),
65 Slice::new(0, Some(h_out as isize), 1),
66 Slice::new(0, Some(w_out as isize), 1),
67 Slice::new(1, Some(2), 1),
68 ];
69
70 let grid_x = B::float_slice(grid.clone(), &grid_x_slice);
71 let grid_x = B::float_reshape(grid_x, Shape::new([n, 1, h_out, w_out]));
72 let grid_y = B::float_slice(grid.clone(), &grid_y_slice);
73 let grid_y = B::float_reshape(grid_y, Shape::new([n, 1, h_out, w_out]));
74
75 let w_in_f = w_in as f64;
77 let h_in_f = h_in as f64;
78
79 let (grid_x, grid_y) = if align_corners {
80 let grid_x = B::float_add_scalar(grid_x, 1.0f32.elem());
83 let grid_x = B::float_mul_scalar(grid_x, ((w_in_f - 1.0) / 2.0).elem());
84
85 let grid_y = B::float_add_scalar(grid_y, 1.0f32.elem());
86 let grid_y = B::float_mul_scalar(grid_y, ((h_in_f - 1.0) / 2.0).elem());
87
88 (grid_x, grid_y)
89 } else {
90 let grid_x = B::float_add_scalar(grid_x, 1.0f32.elem());
93 let grid_x = B::float_mul_scalar(grid_x, (w_in_f / 2.0).elem());
94 let grid_x = B::float_sub_scalar(grid_x, 0.5f32.elem());
95
96 let grid_y = B::float_add_scalar(grid_y, 1.0f32.elem());
97 let grid_y = B::float_mul_scalar(grid_y, (h_in_f / 2.0).elem());
98 let grid_y = B::float_sub_scalar(grid_y, 0.5f32.elem());
99
100 (grid_x, grid_y)
101 };
102
103 let (grid_x, grid_y) = match padding_mode {
105 GridSamplePaddingMode::Border => {
106 let grid_x = B::float_clamp(grid_x, 0.0f32.elem(), ((w_in - 1) as f32).elem());
108 let grid_y = B::float_clamp(grid_y, 0.0f32.elem(), ((h_in - 1) as f32).elem());
109 (grid_x, grid_y)
110 }
111 GridSamplePaddingMode::Reflection => {
112 let grid_x = reflect_coordinates::<B>(grid_x, w_in_f);
115 let grid_y = reflect_coordinates::<B>(grid_y, h_in_f);
116 (grid_x, grid_y)
117 }
118 GridSamplePaddingMode::Zeros => {
119 (grid_x, grid_y)
121 }
122 };
123
124 let grid_x_floored = B::float_floor(grid_x.clone());
126 let grid_y_floored = B::float_floor(grid_y.clone());
127
128 let x_frac = B::float_sub(grid_x.clone(), grid_x_floored.clone());
130 let y_frac = B::float_sub(grid_y.clone(), grid_y_floored.clone());
131
132 let x0 = B::float_into_int(grid_x_floored.clone());
134 let y0 = B::float_into_int(grid_y_floored.clone());
135 let x1 = B::float_into_int(B::float_add_scalar(grid_x_floored, 1.0f32.elem()));
136 let y1 = B::float_into_int(B::float_add_scalar(grid_y_floored, 1.0f32.elem()));
137
138 let (mask_00, mask_01, mask_10, mask_11) = if padding_mode == GridSamplePaddingMode::Zeros {
140 let x0_valid = B::int_greater_equal_elem(x0.clone(), 0.elem());
141 let x0_valid = B::bool_and(
142 x0_valid,
143 B::int_lower_elem(x0.clone(), (w_in as i32).elem()),
144 );
145 let x1_valid = B::int_greater_equal_elem(x1.clone(), 0.elem());
146 let x1_valid = B::bool_and(
147 x1_valid,
148 B::int_lower_elem(x1.clone(), (w_in as i32).elem()),
149 );
150 let y0_valid = B::int_greater_equal_elem(y0.clone(), 0.elem());
151 let y0_valid = B::bool_and(
152 y0_valid,
153 B::int_lower_elem(y0.clone(), (h_in as i32).elem()),
154 );
155 let y1_valid = B::int_greater_equal_elem(y1.clone(), 0.elem());
156 let y1_valid = B::bool_and(
157 y1_valid,
158 B::int_lower_elem(y1.clone(), (h_in as i32).elem()),
159 );
160
161 (
162 Some(B::bool_and(x0_valid.clone(), y0_valid.clone())),
163 Some(B::bool_and(x0_valid.clone(), y1_valid.clone())),
164 Some(B::bool_and(x1_valid.clone(), y0_valid)),
165 Some(B::bool_and(x1_valid, y1_valid)),
166 )
167 } else {
168 (None, None, None, None)
169 };
170
171 let x0_clamped = B::int_clamp(x0, 0.elem(), ((w_in - 1) as i32).elem());
173 let x1_clamped = B::int_clamp(x1, 0.elem(), ((w_in - 1) as i32).elem());
174 let y0_clamped = B::int_clamp(y0, 0.elem(), ((h_in - 1) as i32).elem());
175 let y1_clamped = B::int_clamp(y1, 0.elem(), ((h_in - 1) as i32).elem());
176
177 let y0_idx = B::int_reshape(y0_clamped.clone(), Shape::new([n, 1, h_out, w_out, 1]));
179 let y0_idx = B::int_expand(y0_idx, Shape::new([n, c, h_out, w_out, w_in]));
180 let y1_idx = B::int_reshape(y1_clamped.clone(), Shape::new([n, 1, h_out, w_out, 1]));
181 let y1_idx = B::int_expand(y1_idx, Shape::new([n, c, h_out, w_out, w_in]));
182
183 let x0_idx = B::int_reshape(x0_clamped, Shape::new([n, 1, h_out, w_out, 1]));
184 let x0_idx = B::int_expand(x0_idx, Shape::new([n, c, h_out, w_out, 1]));
185 let x1_idx = B::int_reshape(x1_clamped, Shape::new([n, 1, h_out, w_out, 1]));
186 let x1_idx = B::int_expand(x1_idx, Shape::new([n, c, h_out, w_out, 1]));
187
188 let tensor = B::float_reshape(tensor, Shape::new([n, c, h_in, 1, w_in]));
190 let tensor = B::float_expand(tensor, Shape::new([n, c, h_in, w_out, w_in]));
191
192 let sample_00 = B::float_gather(2, tensor.clone(), y0_idx.clone());
194 let sample_00 = B::float_gather(4, sample_00, x0_idx.clone());
195
196 let sample_01 = B::float_gather(2, tensor.clone(), y1_idx.clone());
197 let sample_01 = B::float_gather(4, sample_01, x0_idx.clone());
198
199 let sample_10 = B::float_gather(2, tensor.clone(), y0_idx);
200 let sample_10 = B::float_gather(4, sample_10, x1_idx.clone());
201
202 let sample_11 = B::float_gather(2, tensor, y1_idx);
203 let sample_11 = B::float_gather(4, sample_11, x1_idx);
204
205 let sample_00 = B::float_reshape(sample_00, Shape::new([n, c, h_out, w_out]));
207 let sample_01 = B::float_reshape(sample_01, Shape::new([n, c, h_out, w_out]));
208 let sample_10 = B::float_reshape(sample_10, Shape::new([n, c, h_out, w_out]));
209 let sample_11 = B::float_reshape(sample_11, Shape::new([n, c, h_out, w_out]));
210
211 let (sample_00, sample_01, sample_10, sample_11) =
213 if padding_mode == GridSamplePaddingMode::Zeros {
214 let mask_00 = mask_00.unwrap();
215 let mask_01 = mask_01.unwrap();
216 let mask_10 = mask_10.unwrap();
217 let mask_11 = mask_11.unwrap();
218
219 let mask_00_inv = B::bool_not(mask_00);
220 let mask_00_inv = B::bool_reshape(mask_00_inv, Shape::new([n, 1, h_out, w_out]));
221 let mask_00_inv = B::bool_expand(mask_00_inv, Shape::new([n, c, h_out, w_out]));
222 let mask_01_inv = B::bool_not(mask_01);
223 let mask_01_inv = B::bool_reshape(mask_01_inv, Shape::new([n, 1, h_out, w_out]));
224 let mask_01_inv = B::bool_expand(mask_01_inv, Shape::new([n, c, h_out, w_out]));
225 let mask_10_inv = B::bool_not(mask_10);
226 let mask_10_inv = B::bool_reshape(mask_10_inv, Shape::new([n, 1, h_out, w_out]));
227 let mask_10_inv = B::bool_expand(mask_10_inv, Shape::new([n, c, h_out, w_out]));
228 let mask_11_inv = B::bool_not(mask_11);
229 let mask_11_inv = B::bool_reshape(mask_11_inv, Shape::new([n, 1, h_out, w_out]));
230 let mask_11_inv = B::bool_expand(mask_11_inv, Shape::new([n, c, h_out, w_out]));
231
232 (
233 B::float_mask_fill(sample_00, mask_00_inv, 0.0f32.elem()),
234 B::float_mask_fill(sample_01, mask_01_inv, 0.0f32.elem()),
235 B::float_mask_fill(sample_10, mask_10_inv, 0.0f32.elem()),
236 B::float_mask_fill(sample_11, mask_11_inv, 0.0f32.elem()),
237 )
238 } else {
239 (sample_00, sample_01, sample_10, sample_11)
240 };
241
242 let one_minus_x = B::float_neg(x_frac.clone());
244 let one_minus_x = B::float_add_scalar(one_minus_x, 1.0f32.elem());
245
246 let one_minus_y = B::float_neg(y_frac.clone());
247 let one_minus_y = B::float_add_scalar(one_minus_y, 1.0f32.elem());
248
249 let weight_00 = B::float_mul(one_minus_x.clone(), one_minus_y.clone());
250 let weight_01 = B::float_mul(one_minus_x.clone(), y_frac.clone());
251 let weight_10 = B::float_mul(x_frac.clone(), one_minus_y);
252 let weight_11 = B::float_mul(x_frac, y_frac);
253
254 let result = B::float_mul(sample_00, weight_00);
256 let result = B::float_add(result, B::float_mul(sample_01, weight_01));
257 let result = B::float_add(result, B::float_mul(sample_10, weight_10));
258
259 B::float_add(result, B::float_mul(sample_11, weight_11))
260}
261
262fn reflect_coordinates<B: Backend>(coords: FloatTensor<B>, size: f64) -> FloatTensor<B> {
266 let max_val = (size - 1.0) as f32;
272
273 let coords = B::float_abs(coords);
275
276 B::float_clamp(coords, 0.0f32.elem(), max_val.elem())
280}