burn_tensor/tensor/ops/modules/
grid_sample.rs

1use crate::{ElementConversion, Shape, Slice, TensorMetadata, backend::Backend, ops::FloatTensor};
2use alloc::vec;
3
4/// Default implementation of float_grid_sample_2d with bilinear interpolation and border padding
5///
6/// # Arguments
7///
8/// * `tensor` - The tensor being sampled from, shape (N, C, H_in, W_in)
9/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
10///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
11///
12/// # Returns
13///
14/// A tensor with shape (N, C, H_out, W_out)
15pub fn float_grid_sample_2d_bilinear<B: Backend>(
16    tensor: FloatTensor<B>,
17    grid: FloatTensor<B>,
18) -> FloatTensor<B> {
19    let n = tensor.shape().dims[0];
20    let c = tensor.shape().dims[1];
21    let h_in = tensor.shape().dims[2];
22    let w_in = tensor.shape().dims[3];
23    let h_out = grid.shape().dims[1];
24    let w_out = grid.shape().dims[2];
25
26    let x_max_half = (w_in - 1) as f64 / 2.0;
27    let y_max_half = (h_in - 1) as f64 / 2.0;
28
29    // Clamp grid
30    let grid = B::float_clamp(grid, (-1_f32).elem(), (1_f32).elem());
31
32    // Separate x and y coordinates
33    // shape: (N, H_out, W_out, 1)
34    let grid_x_slice = vec![
35        Slice::new(0, Some(n as isize), 1),
36        Slice::new(0, Some(h_out as isize), 1),
37        Slice::new(0, Some(w_out as isize), 1),
38        Slice::new(0, Some(1), 1),
39    ];
40    let grid_y_slice = vec![
41        Slice::new(0, Some(n as isize), 1),
42        Slice::new(0, Some(h_out as isize), 1),
43        Slice::new(0, Some(w_out as isize), 1),
44        Slice::new(1, Some(2), 1),
45    ];
46
47    let grid_x = B::float_slice(grid.clone(), &grid_x_slice);
48    let grid_x = B::float_reshape(grid_x, Shape::new([n, 1, h_out, w_out]));
49    let grid_y = B::float_slice(grid.clone(), &grid_y_slice);
50    let grid_y = B::float_reshape(grid_y, Shape::new([n, 1, h_out, w_out]));
51
52    // Scale grid locations from [-1, 1] and [-1, 1] to [0..W_out] and [0..H_out]
53    let grid_x = B::float_mul_scalar(grid_x, x_max_half.elem());
54    let grid_x = B::float_add_scalar(grid_x, x_max_half.elem());
55    let grid_y = B::float_mul_scalar(grid_y, x_max_half.elem());
56    let grid_y = B::float_add_scalar(grid_y, y_max_half.elem());
57
58    // Get low and high x locations
59    let grid_x_floored = B::float_floor(grid_x.clone());
60    let grid_x_plus_one = B::float_floor(B::float_add_scalar(grid_x.clone(), 1.elem()));
61    let x_indices_low = B::float_into_int(grid_x_floored.clone());
62    let x_indices_high = B::float_into_int(grid_x_plus_one.clone());
63
64    // Get low and high x locations
65    let grid_y_floored = B::float_floor(grid_y.clone());
66    let grid_y_plus_one = B::float_floor(B::float_add_scalar(grid_y.clone(), 1.elem()));
67    let y_indices_low = B::float_into_int(grid_y_floored.clone());
68    let y_indices_high = B::float_into_int(grid_y_plus_one.clone());
69
70    // Clamp locations: border padding
71    let x_indices_low = B::int_clamp(x_indices_low, 0.elem(), ((w_in - 1) as u32).elem());
72    let x_indices_high = B::int_clamp(x_indices_high, 0.elem(), ((w_in - 1) as u32).elem());
73    let y_indices_low = B::int_clamp(y_indices_low, 0.elem(), ((h_in - 1) as u32).elem());
74    let y_indices_high = B::int_clamp(y_indices_high, 0.elem(), ((h_in - 1) as u32).elem());
75
76    // Needs shape (N, C, H_out, W_out, W_in) for the first gather operationd
77    let y_indices_low = B::int_reshape(y_indices_low, Shape::new([n, 1, h_out, w_out, 1]));
78    let y_indices_low = B::int_expand(y_indices_low, Shape::new([n, c, h_out, w_out, w_in]));
79    let y_indices_high = B::int_reshape(y_indices_high, Shape::new([n, 1, h_out, w_out, 1]));
80    let y_indices_high = B::int_expand(y_indices_high, Shape::new([n, c, h_out, w_out, w_in]));
81
82    // Needs shape (N, C, H_out, W_out, 1) for the second gather operation
83    let x_indices_low = B::int_reshape(x_indices_low, Shape::new([n, 1, h_out, w_out, 1]));
84    let x_indices_low = B::int_expand(x_indices_low, Shape::new([n, c, h_out, w_out, 1]));
85    let x_indices_high = B::int_reshape(x_indices_high, Shape::new([n, 1, h_out, w_out, 1]));
86    let x_indices_high = B::int_expand(x_indices_high, Shape::new([n, c, h_out, w_out, 1]));
87
88    // Reshape tensor for gather operation
89    let tensor = B::float_reshape(tensor, Shape::new([n, c, h_in, 1, w_in]));
90    let tensor = B::float_expand(tensor, Shape::new([n, c, h_in, w_out, w_in]));
91
92    // Gather on x and y. Watch out for the shapes
93    let sample_00 = B::float_gather(2, tensor.clone(), y_indices_low.clone());
94    let sample_00 = B::float_gather(4, sample_00, x_indices_low.clone());
95
96    let sample_01 = B::float_gather(2, tensor.clone(), y_indices_high.clone());
97    let sample_01 = B::float_gather(4, sample_01, x_indices_low.clone());
98
99    let sample_10 = B::float_gather(2, tensor.clone(), y_indices_low.clone());
100    let sample_10 = B::float_gather(4, sample_10, x_indices_high.clone());
101
102    let sample_11 = B::float_gather(2, tensor, y_indices_high);
103    let sample_11 = B::float_gather(4, sample_11, x_indices_high);
104
105    // Reshape to (N, C, H_out, W_out) for multiplying with weights
106    let sample_00 = B::float_reshape(sample_00, Shape::new([n, c, h_out, w_out]));
107    let sample_01 = B::float_reshape(sample_01, Shape::new([n, c, h_out, w_out]));
108    let sample_10 = B::float_reshape(sample_10, Shape::new([n, c, h_out, w_out]));
109    let sample_11 = B::float_reshape(sample_11, Shape::new([n, c, h_out, w_out]));
110
111    // Weights for bilinear interp
112    let weight_00 = B::float_mul(
113        B::float_sub(grid_x_plus_one.clone(), grid_x.clone()),
114        B::float_sub(grid_y_plus_one.clone(), grid_y.clone()),
115    );
116    let weight_10 = B::float_mul(
117        B::float_sub(grid_x.clone(), grid_x_floored.clone()),
118        B::float_sub(grid_y_plus_one.clone(), grid_y.clone()),
119    );
120    let weight_01 = B::float_mul(
121        B::float_sub(grid_x_plus_one.clone(), grid_x.clone()),
122        B::float_sub(grid_y.clone(), grid_y_floored.clone()),
123    );
124    let weight_11 = B::float_mul(
125        B::float_sub(grid_x.clone(), grid_x_floored),
126        B::float_sub(grid_y.clone(), grid_y_floored),
127    );
128
129    // Bilinear interp
130    let sample_0 = B::float_add(
131        B::float_mul(sample_00, weight_00),
132        B::float_mul(sample_01, weight_01),
133    );
134    let sample_1 = B::float_add(
135        B::float_mul(sample_10, weight_10),
136        B::float_mul(sample_11, weight_11),
137    );
138    B::float_add(sample_0, sample_1)
139}