Skip to main content

burn_backend/backend/ops/modules/
grid_sample.rs

1use crate::{
2    Backend, TensorMetadata, get_device_settings,
3    ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},
4    tensor::FloatTensor,
5};
6use alloc::vec;
7use burn_std::{Shape, Slice};
8
9/// Reference implementation of grid_sample_2d that supports all options.
10///
11/// # Arguments
12///
13/// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
14/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
15///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
16/// * `options` - Grid sampling options
17///
18/// # Returns
19///
20/// A tensor with shape (N, C, H_out, W_out)
21pub fn float_grid_sample_2d_ref<B: Backend>(
22    tensor: FloatTensor<B>,
23    grid: FloatTensor<B>,
24    options: GridSampleOptions,
25) -> FloatTensor<B> {
26    match options.mode {
27        InterpolateMode::Bilinear => float_grid_sample_2d_bilinear::<B>(
28            tensor,
29            grid,
30            options.padding_mode,
31            options.align_corners,
32        ),
33        _ => todo!(
34            "Default implementation for grid_sample_2d with {:?} unimplemented",
35            options.mode
36        ),
37    }
38}
39
40/// Bilinear grid sampling implementation.
41fn float_grid_sample_2d_bilinear<B: Backend>(
42    tensor: FloatTensor<B>,
43    grid: FloatTensor<B>,
44    padding_mode: GridSamplePaddingMode,
45    align_corners: bool,
46) -> FloatTensor<B> {
47    let n = tensor.shape()[0];
48    let c = tensor.shape()[1];
49    let h_in = tensor.shape()[2];
50    let w_in = tensor.shape()[3];
51    let h_out = grid.shape()[1];
52    let w_out = grid.shape()[2];
53    let spatial_in = h_in * w_in;
54    let spatial_out = h_out * w_out;
55    let device = B::float_device(&tensor);
56
57    // Separate x and y coordinates from grid
58    // shape: (N, H_out, W_out, 1)
59    let grid_x_slice = vec![
60        Slice::new(0, Some(n as isize), 1),
61        Slice::new(0, Some(h_out as isize), 1),
62        Slice::new(0, Some(w_out as isize), 1),
63        Slice::new(0, Some(1), 1),
64    ];
65    let grid_y_slice = vec![
66        Slice::new(0, Some(n as isize), 1),
67        Slice::new(0, Some(h_out as isize), 1),
68        Slice::new(0, Some(w_out as isize), 1),
69        Slice::new(1, Some(2), 1),
70    ];
71
72    let grid_x = B::float_slice(grid.clone(), &grid_x_slice);
73    let grid_x = B::float_reshape(grid_x, Shape::new([n, 1, h_out, w_out]));
74    let grid_y = B::float_slice(grid.clone(), &grid_y_slice);
75    let grid_y = B::float_reshape(grid_y, Shape::new([n, 1, h_out, w_out]));
76
77    // Convert normalized grid coordinates [-1, 1] to pixel coordinates
78    let w_in_f = w_in as f64;
79    let h_in_f = h_in as f64;
80
81    let (grid_x, grid_y) = if align_corners {
82        // align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2
83        // Maps -1 to 0 and 1 to width - 1
84        let grid_x = B::float_add_scalar(grid_x, 1f32.into());
85        let grid_x = B::float_mul_scalar(grid_x, ((w_in_f - 1.0) / 2.0).into());
86
87        let grid_y = B::float_add_scalar(grid_y, 1f32.into());
88        let grid_y = B::float_mul_scalar(grid_y, ((h_in_f - 1.0) / 2.0).into());
89
90        (grid_x, grid_y)
91    } else {
92        // align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5
93        // Maps -1 to -0.5 and 1 to width - 0.5
94        let grid_x = B::float_add_scalar(grid_x, 1f32.into());
95        let grid_x = B::float_mul_scalar(grid_x, (w_in_f / 2.0).into());
96        let grid_x = B::float_sub_scalar(grid_x, 0.5f32.into());
97
98        let grid_y = B::float_add_scalar(grid_y, 1f32.into());
99        let grid_y = B::float_mul_scalar(grid_y, (h_in_f / 2.0).into());
100        let grid_y = B::float_sub_scalar(grid_y, 0.5f32.into());
101
102        (grid_x, grid_y)
103    };
104
105    // Apply padding mode to coordinates
106    let (grid_x, grid_y) = match padding_mode {
107        GridSamplePaddingMode::Border => {
108            // Clamp coordinates to valid range [0, size-1]
109            let grid_x = B::float_clamp(grid_x, 0f32.into(), ((w_in - 1) as f32).into());
110            let grid_y = B::float_clamp(grid_y, 0f32.into(), ((h_in - 1) as f32).into());
111            (grid_x, grid_y)
112        }
113        GridSamplePaddingMode::Reflection => {
114            // Reflect coordinates at boundaries
115            let grid_x = reflect_coordinates::<B>(grid_x, w_in_f, align_corners);
116            let grid_y = reflect_coordinates::<B>(grid_y, h_in_f, align_corners);
117            (grid_x, grid_y)
118        }
119        GridSamplePaddingMode::Zeros => {
120            // Keep coordinates as-is, we'll mask out-of-bounds later
121            (grid_x, grid_y)
122        }
123    };
124
125    // Get floor indices for the four corners
126    let grid_x_floored = B::float_floor(grid_x.clone());
127    let grid_y_floored = B::float_floor(grid_y.clone());
128
129    // Compute interpolation weights (fractional part)
130    let x_frac = B::float_sub(grid_x.clone(), grid_x_floored.clone());
131    let y_frac = B::float_sub(grid_y.clone(), grid_y_floored.clone());
132
133    // Convert to integer indices
134    let settings = get_device_settings::<B>(&device);
135    let x0 = B::float_into_int(grid_x_floored.clone(), settings.int_dtype);
136    let y0 = B::float_into_int(grid_y_floored.clone(), settings.int_dtype);
137    let x1 = B::float_into_int(
138        B::float_add_scalar(grid_x_floored, 1f32.into()),
139        settings.int_dtype,
140    );
141    let y1 = B::float_into_int(
142        B::float_add_scalar(grid_y_floored, 1f32.into()),
143        settings.int_dtype,
144    );
145
146    // Create masks for out-of-bounds coordinates (only used for zeros padding)
147    let (mask_00, mask_01, mask_10, mask_11) = if padding_mode == GridSamplePaddingMode::Zeros {
148        let x0_valid = B::int_greater_equal_elem(x0.clone(), 0.into(), settings.bool_dtype);
149        let x0_valid = B::bool_and(
150            x0_valid,
151            B::int_lower_elem(x0.clone(), (w_in as i32).into(), settings.bool_dtype),
152        );
153        let x1_valid = B::int_greater_equal_elem(x1.clone(), 0.into(), settings.bool_dtype);
154        let x1_valid = B::bool_and(
155            x1_valid,
156            B::int_lower_elem(x1.clone(), (w_in as i32).into(), settings.bool_dtype),
157        );
158        let y0_valid = B::int_greater_equal_elem(y0.clone(), 0.into(), settings.bool_dtype);
159        let y0_valid = B::bool_and(
160            y0_valid,
161            B::int_lower_elem(y0.clone(), (h_in as i32).into(), settings.bool_dtype),
162        );
163        let y1_valid = B::int_greater_equal_elem(y1.clone(), 0.into(), settings.bool_dtype);
164        let y1_valid = B::bool_and(
165            y1_valid,
166            B::int_lower_elem(y1.clone(), (h_in as i32).into(), settings.bool_dtype),
167        );
168
169        (
170            Some(B::bool_and(x0_valid.clone(), y0_valid.clone())),
171            Some(B::bool_and(x0_valid.clone(), y1_valid.clone())),
172            Some(B::bool_and(x1_valid.clone(), y0_valid)),
173            Some(B::bool_and(x1_valid, y1_valid)),
174        )
175    } else {
176        (None, None, None, None)
177    };
178
179    // Clamp indices to valid range for gather
180    let x0_clamped = B::int_clamp(x0, 0.into(), ((w_in - 1) as i32).into());
181    let x1_clamped = B::int_clamp(x1, 0.into(), ((w_in - 1) as i32).into());
182    let y0_clamped = B::int_clamp(y0, 0.into(), ((h_in - 1) as i32).into());
183    let y1_clamped = B::int_clamp(y1, 0.into(), ((h_in - 1) as i32).into());
184
185    // Linear indices: idx = y * W_in + x
186    let w_in_scalar: i32 = w_in as i32;
187    let idx_00 = B::int_add(
188        B::int_mul_scalar(y0_clamped.clone(), w_in_scalar.into()),
189        x0_clamped.clone(),
190    );
191    let idx_01 = B::int_add(
192        B::int_mul_scalar(y1_clamped.clone(), w_in_scalar.into()),
193        x0_clamped,
194    );
195    let idx_10 = B::int_add(
196        B::int_mul_scalar(y0_clamped, w_in_scalar.into()),
197        x1_clamped.clone(),
198    );
199    let idx_11 = B::int_add(
200        B::int_mul_scalar(y1_clamped, w_in_scalar.into()),
201        x1_clamped,
202    );
203
204    // [N, 1, H_out, W_out] -> [N, 1, H_out * W_out]
205    let idx_00 = B::int_reshape(idx_00, Shape::new([n, 1, spatial_out]));
206    let idx_01 = B::int_reshape(idx_01, Shape::new([n, 1, spatial_out]));
207    let idx_10 = B::int_reshape(idx_10, Shape::new([n, 1, spatial_out]));
208    let idx_11 = B::int_reshape(idx_11, Shape::new([n, 1, spatial_out]));
209
210    // [N, 1, spatial] -> [N, C, spatial]
211    let idx_00 = B::int_expand(idx_00, Shape::new([n, c, spatial_out]));
212    let idx_01 = B::int_expand(idx_01, Shape::new([n, c, spatial_out]));
213    let idx_10 = B::int_expand(idx_10, Shape::new([n, c, spatial_out]));
214    let idx_11 = B::int_expand(idx_11, Shape::new([n, c, spatial_out]));
215
216    let tensor_flat = B::float_reshape(tensor, Shape::new([n, c, spatial_in]));
217
218    let sample_00 = B::float_gather(2, tensor_flat.clone(), idx_00);
219    let sample_01 = B::float_gather(2, tensor_flat.clone(), idx_01);
220    let sample_10 = B::float_gather(2, tensor_flat.clone(), idx_10);
221    let sample_11 = B::float_gather(2, tensor_flat, idx_11);
222
223    // Reshape samples to (N, C, H_out, W_out)
224    let sample_00 = B::float_reshape(sample_00, Shape::new([n, c, h_out, w_out]));
225    let sample_01 = B::float_reshape(sample_01, Shape::new([n, c, h_out, w_out]));
226    let sample_10 = B::float_reshape(sample_10, Shape::new([n, c, h_out, w_out]));
227    let sample_11 = B::float_reshape(sample_11, Shape::new([n, c, h_out, w_out]));
228
229    // Apply masks for zeros padding (set out-of-bounds samples to 0)
230    let (sample_00, sample_01, sample_10, sample_11) =
231        if padding_mode == GridSamplePaddingMode::Zeros {
232            let mask_00 = mask_00.unwrap();
233            let mask_01 = mask_01.unwrap();
234            let mask_10 = mask_10.unwrap();
235            let mask_11 = mask_11.unwrap();
236
237            let mask_00_inv = B::bool_not(mask_00);
238            let mask_00_inv = B::bool_reshape(mask_00_inv, Shape::new([n, 1, h_out, w_out]));
239            let mask_00_inv = B::bool_expand(mask_00_inv, Shape::new([n, c, h_out, w_out]));
240            let mask_01_inv = B::bool_not(mask_01);
241            let mask_01_inv = B::bool_reshape(mask_01_inv, Shape::new([n, 1, h_out, w_out]));
242            let mask_01_inv = B::bool_expand(mask_01_inv, Shape::new([n, c, h_out, w_out]));
243            let mask_10_inv = B::bool_not(mask_10);
244            let mask_10_inv = B::bool_reshape(mask_10_inv, Shape::new([n, 1, h_out, w_out]));
245            let mask_10_inv = B::bool_expand(mask_10_inv, Shape::new([n, c, h_out, w_out]));
246            let mask_11_inv = B::bool_not(mask_11);
247            let mask_11_inv = B::bool_reshape(mask_11_inv, Shape::new([n, 1, h_out, w_out]));
248            let mask_11_inv = B::bool_expand(mask_11_inv, Shape::new([n, c, h_out, w_out]));
249
250            (
251                B::float_mask_fill(sample_00, mask_00_inv, 0f32.into()),
252                B::float_mask_fill(sample_01, mask_01_inv, 0f32.into()),
253                B::float_mask_fill(sample_10, mask_10_inv, 0f32.into()),
254                B::float_mask_fill(sample_11, mask_11_inv, 0f32.into()),
255            )
256        } else {
257            (sample_00, sample_01, sample_10, sample_11)
258        };
259
260    // Compute bilinear interpolation weights
261    let one_minus_x = B::float_neg(x_frac.clone());
262    let one_minus_x = B::float_add_scalar(one_minus_x, 1f32.into());
263
264    let one_minus_y = B::float_neg(y_frac.clone());
265    let one_minus_y = B::float_add_scalar(one_minus_y, 1f32.into());
266
267    let weight_00 = B::float_mul(one_minus_x.clone(), one_minus_y.clone());
268    let weight_01 = B::float_mul(one_minus_x.clone(), y_frac.clone());
269    let weight_10 = B::float_mul(x_frac.clone(), one_minus_y);
270    let weight_11 = B::float_mul(x_frac, y_frac);
271
272    // Bilinear interpolation
273    let result = B::float_mul(sample_00, weight_00);
274    let result = B::float_add(result, B::float_mul(sample_01, weight_01));
275    let result = B::float_add(result, B::float_mul(sample_10, weight_10));
276
277    B::float_add(result, B::float_mul(sample_11, weight_11))
278}
279
280/// Reflect coordinates at boundaries using a triangle wave pattern.
281///
282/// For align_corners=true: reflects within [0, size-1]
283/// For align_corners=false: reflects within [-0.5, size-0.5]
284fn reflect_coordinates<B: Backend>(
285    coords: FloatTensor<B>,
286    size: f64,
287    align_corners: bool,
288) -> FloatTensor<B> {
289    let (min_val, max_val) = if align_corners {
290        (0.0f32, (size - 1.0) as f32)
291    } else {
292        (-0.5f32, (size - 0.5) as f32)
293    };
294
295    let span = max_val - min_val;
296    if span <= 0.0 {
297        // Edge case: size is 1, just return min_val everywhere
298        let zeros = B::float_mul_scalar(coords, 0f32.into());
299        return B::float_add_scalar(zeros, min_val.into());
300    }
301
302    // Triangle wave formula: span - |((x mod 2*span) - span)| + min_val
303    let period = 2.0 * span;
304
305    // x = abs(coord - min_val)
306    let x = B::float_sub_scalar(coords, min_val.into());
307    let x = B::float_abs(x);
308
309    // x_mod = x - floor(x / period) * period
310    let x_div = B::float_div_scalar(x.clone(), period.into());
311    let x_div_floor = B::float_floor(x_div);
312    let x_mod = B::float_sub(x, B::float_mul_scalar(x_div_floor, period.into()));
313
314    // result = span - abs(x_mod - span) + min_val
315    let diff = B::float_sub_scalar(x_mod, span.into());
316    let abs_diff = B::float_abs(diff);
317    let reflected = B::float_sub_scalar(abs_diff, span.into());
318    let reflected = B::float_neg(reflected);
319    B::float_add_scalar(reflected, min_val.into())
320}