1use crate::{ElementConversion, Shape, Slice, TensorMetadata, backend::Backend, ops::FloatTensor};
2use alloc::vec;
3
4pub fn float_grid_sample_2d_bilinear<B: Backend>(
16 tensor: FloatTensor<B>,
17 grid: FloatTensor<B>,
18) -> FloatTensor<B> {
19 let n = tensor.shape().dims[0];
20 let c = tensor.shape().dims[1];
21 let h_in = tensor.shape().dims[2];
22 let w_in = tensor.shape().dims[3];
23 let h_out = grid.shape().dims[1];
24 let w_out = grid.shape().dims[2];
25
26 let x_max_half = (w_in - 1) as f64 / 2.0;
27 let y_max_half = (h_in - 1) as f64 / 2.0;
28
29 let grid = B::float_clamp(grid, (-1_f32).elem(), (1_f32).elem());
31
32 let grid_x_slice = vec![
35 Slice::new(0, Some(n as isize), 1),
36 Slice::new(0, Some(h_out as isize), 1),
37 Slice::new(0, Some(w_out as isize), 1),
38 Slice::new(0, Some(1), 1),
39 ];
40 let grid_y_slice = vec![
41 Slice::new(0, Some(n as isize), 1),
42 Slice::new(0, Some(h_out as isize), 1),
43 Slice::new(0, Some(w_out as isize), 1),
44 Slice::new(1, Some(2), 1),
45 ];
46
47 let grid_x = B::float_slice(grid.clone(), &grid_x_slice);
48 let grid_x = B::float_reshape(grid_x, Shape::new([n, 1, h_out, w_out]));
49 let grid_y = B::float_slice(grid.clone(), &grid_y_slice);
50 let grid_y = B::float_reshape(grid_y, Shape::new([n, 1, h_out, w_out]));
51
52 let grid_x = B::float_mul_scalar(grid_x, x_max_half.elem());
54 let grid_x = B::float_add_scalar(grid_x, x_max_half.elem());
55 let grid_y = B::float_mul_scalar(grid_y, x_max_half.elem());
56 let grid_y = B::float_add_scalar(grid_y, y_max_half.elem());
57
58 let grid_x_floored = B::float_floor(grid_x.clone());
60 let grid_x_plus_one = B::float_floor(B::float_add_scalar(grid_x.clone(), 1.elem()));
61 let x_indices_low = B::float_into_int(grid_x_floored.clone());
62 let x_indices_high = B::float_into_int(grid_x_plus_one.clone());
63
64 let grid_y_floored = B::float_floor(grid_y.clone());
66 let grid_y_plus_one = B::float_floor(B::float_add_scalar(grid_y.clone(), 1.elem()));
67 let y_indices_low = B::float_into_int(grid_y_floored.clone());
68 let y_indices_high = B::float_into_int(grid_y_plus_one.clone());
69
70 let x_indices_low = B::int_clamp(x_indices_low, 0.elem(), ((w_in - 1) as u32).elem());
72 let x_indices_high = B::int_clamp(x_indices_high, 0.elem(), ((w_in - 1) as u32).elem());
73 let y_indices_low = B::int_clamp(y_indices_low, 0.elem(), ((h_in - 1) as u32).elem());
74 let y_indices_high = B::int_clamp(y_indices_high, 0.elem(), ((h_in - 1) as u32).elem());
75
76 let y_indices_low = B::int_reshape(y_indices_low, Shape::new([n, 1, h_out, w_out, 1]));
78 let y_indices_low = B::int_expand(y_indices_low, Shape::new([n, c, h_out, w_out, w_in]));
79 let y_indices_high = B::int_reshape(y_indices_high, Shape::new([n, 1, h_out, w_out, 1]));
80 let y_indices_high = B::int_expand(y_indices_high, Shape::new([n, c, h_out, w_out, w_in]));
81
82 let x_indices_low = B::int_reshape(x_indices_low, Shape::new([n, 1, h_out, w_out, 1]));
84 let x_indices_low = B::int_expand(x_indices_low, Shape::new([n, c, h_out, w_out, 1]));
85 let x_indices_high = B::int_reshape(x_indices_high, Shape::new([n, 1, h_out, w_out, 1]));
86 let x_indices_high = B::int_expand(x_indices_high, Shape::new([n, c, h_out, w_out, 1]));
87
88 let tensor = B::float_reshape(tensor, Shape::new([n, c, h_in, 1, w_in]));
90 let tensor = B::float_expand(tensor, Shape::new([n, c, h_in, w_out, w_in]));
91
92 let sample_00 = B::float_gather(2, tensor.clone(), y_indices_low.clone());
94 let sample_00 = B::float_gather(4, sample_00, x_indices_low.clone());
95
96 let sample_01 = B::float_gather(2, tensor.clone(), y_indices_high.clone());
97 let sample_01 = B::float_gather(4, sample_01, x_indices_low.clone());
98
99 let sample_10 = B::float_gather(2, tensor.clone(), y_indices_low.clone());
100 let sample_10 = B::float_gather(4, sample_10, x_indices_high.clone());
101
102 let sample_11 = B::float_gather(2, tensor, y_indices_high);
103 let sample_11 = B::float_gather(4, sample_11, x_indices_high);
104
105 let sample_00 = B::float_reshape(sample_00, Shape::new([n, c, h_out, w_out]));
107 let sample_01 = B::float_reshape(sample_01, Shape::new([n, c, h_out, w_out]));
108 let sample_10 = B::float_reshape(sample_10, Shape::new([n, c, h_out, w_out]));
109 let sample_11 = B::float_reshape(sample_11, Shape::new([n, c, h_out, w_out]));
110
111 let weight_00 = B::float_mul(
113 B::float_sub(grid_x_plus_one.clone(), grid_x.clone()),
114 B::float_sub(grid_y_plus_one.clone(), grid_y.clone()),
115 );
116 let weight_10 = B::float_mul(
117 B::float_sub(grid_x.clone(), grid_x_floored.clone()),
118 B::float_sub(grid_y_plus_one.clone(), grid_y.clone()),
119 );
120 let weight_01 = B::float_mul(
121 B::float_sub(grid_x_plus_one.clone(), grid_x.clone()),
122 B::float_sub(grid_y.clone(), grid_y_floored.clone()),
123 );
124 let weight_11 = B::float_mul(
125 B::float_sub(grid_x.clone(), grid_x_floored),
126 B::float_sub(grid_y.clone(), grid_y_floored),
127 );
128
129 let sample_0 = B::float_add(
131 B::float_mul(sample_00, weight_00),
132 B::float_mul(sample_01, weight_01),
133 );
134 let sample_1 = B::float_add(
135 B::float_mul(sample_10, weight_10),
136 B::float_mul(sample_11, weight_11),
137 );
138 B::float_add(sample_0, sample_1)
139}