1use burn_backend::ElementConversion;
2use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
3#[cfg(not(feature = "std"))]
4#[allow(unused_imports)]
5use num_traits::Float;
6
7use ndarray::Array4;
8
9use crate::SharedArray;
10use crate::{FloatNdArrayElement, UnsafeSharedRef, iter_range_par, run_par};
11
12pub(crate) fn grid_sample_2d<E: FloatNdArrayElement>(
25 tensor: SharedArray<E>,
26 grid: SharedArray<E>,
27 options: GridSampleOptions,
28) -> SharedArray<E> {
29 match options.mode {
30 InterpolateMode::Bilinear => (),
31 _ => todo!(
32 "grid_sample_2d with {:?} mode is not implemented",
33 options.mode
34 ),
35 }
36
37 let tensor = tensor.into_dimensionality::<ndarray::Ix4>().unwrap();
38 let grid = grid.into_dimensionality::<ndarray::Ix4>().unwrap();
39
40 let (batch_size, channels, height_in, width_in) = tensor.dim();
41 let (b, height_out, width_out, d) = grid.dim();
42 assert!(batch_size == b);
43 assert!(2 == d);
44
45 let mut output = Array4::zeros((batch_size, channels, height_out, width_out));
46 let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
47
48 let sample_count = batch_size * channels * height_out * width_out;
49 let strides = (
50 channels * height_out * width_out,
51 height_out * width_out,
52 width_out,
53 );
54
55 let align = options.align_corners;
56 let pad_mode = options.padding_mode;
57
58 run_par!(|| {
59 iter_range_par!(0, sample_count).for_each(|id| {
60 let (b, c, y, x) = (
61 id / strides.0,
62 id % strides.0 / strides.1,
63 id % strides.1 / strides.2,
64 id % strides.2,
65 );
66
67 let sample_x = grid[(b, y, x, 0)].elem::<f64>();
68 let sample_y = grid[(b, y, x, 1)].elem::<f64>();
69
70 let (px, py) = if align {
72 let px = (sample_x + 1.0) * ((width_in - 1) as f64) / 2.0;
75 let py = (sample_y + 1.0) * ((height_in - 1) as f64) / 2.0;
76 (px, py)
77 } else {
78 let px = (sample_x + 1.0) * (width_in as f64) / 2.0 - 0.5;
81 let py = (sample_y + 1.0) * (height_in as f64) / 2.0 - 0.5;
82 (px, py)
83 };
84
85 let val =
87 bilinear_interpolate(&tensor, b, c, px, py, width_in, height_in, pad_mode, align);
88
89 unsafe {
90 let output = unsafe_shared_out.get();
91 output[(b, c, y, x)] = val.elem();
92 }
93 });
94 });
95
96 output.into_dyn().into_shared()
97}
98
99#[allow(clippy::too_many_arguments)]
101fn bilinear_interpolate<E, S>(
102 source: &ndarray::ArrayBase<S, ndarray::Dim<[usize; 4]>>,
103 b: usize,
104 c: usize,
105 x: f64,
106 y: f64,
107 width: usize,
108 height: usize,
109 padding_mode: GridSamplePaddingMode,
110 align_corners: bool,
111) -> f64
112where
113 E: FloatNdArrayElement,
114 S: ndarray::Data<Elem = E>,
115{
116 if !x.is_finite() || !y.is_finite() {
118 return match padding_mode {
119 GridSamplePaddingMode::Zeros => 0.0,
120 GridSamplePaddingMode::Border => {
121 let cx = ((width - 1) as f64 / 2.0).clamp(0.0, (width - 1) as f64);
123 let cy = ((height - 1) as f64 / 2.0).clamp(0.0, (height - 1) as f64);
124 source[(b, c, cy as usize, cx as usize)].elem::<f64>()
125 }
126 GridSamplePaddingMode::Reflection => 0.0, };
128 }
129
130 let (x, y) = match padding_mode {
132 GridSamplePaddingMode::Border => {
133 let x = x.clamp(0.0, (width - 1) as f64);
135 let y = y.clamp(0.0, (height - 1) as f64);
136 (x, y)
137 }
138 GridSamplePaddingMode::Reflection => {
139 let x = reflect_coordinate(x, width, align_corners);
141 let y = reflect_coordinate(y, height, align_corners);
142 (x, y)
143 }
144 GridSamplePaddingMode::Zeros => (x, y), };
146
147 let x0 = x.floor() as i64;
149 let y0 = y.floor() as i64;
150 let x1 = x0.saturating_add(1);
151 let y1 = y0.saturating_add(1);
152
153 let x_frac = x - x.floor();
155 let y_frac = y - y.floor();
156
157 let read_value = |xi: i64, yi: i64| -> f64 {
159 match padding_mode {
160 GridSamplePaddingMode::Zeros => {
161 if xi >= 0 && xi < width as i64 && yi >= 0 && yi < height as i64 {
163 source[(b, c, yi as usize, xi as usize)].elem::<f64>()
164 } else {
165 0.0
166 }
167 }
168 GridSamplePaddingMode::Border | GridSamplePaddingMode::Reflection => {
169 let xi = xi.clamp(0, (width - 1) as i64) as usize;
171 let yi = yi.clamp(0, (height - 1) as i64) as usize;
172 source[(b, c, yi, xi)].elem::<f64>()
173 }
174 }
175 };
176
177 let v00 = read_value(x0, y0);
179 let v01 = read_value(x0, y1);
180 let v10 = read_value(x1, y0);
181 let v11 = read_value(x1, y1);
182
183 let w00 = (1.0 - x_frac) * (1.0 - y_frac);
185 let w01 = (1.0 - x_frac) * y_frac;
186 let w10 = x_frac * (1.0 - y_frac);
187 let w11 = x_frac * y_frac;
188
189 v00 * w00 + v01 * w01 + v10 * w10 + v11 * w11
190}
191
/// Mirrors `coord` back into the valid sampling range, as used by reflection padding.
///
/// The range is `[0, size - 1]` when `align_corners` is true and `[-0.5, size - 0.5]`
/// otherwise. Coordinates outside it are reflected repeatedly at the range edges. A
/// zero-width or inverted range (e.g. `size <= 1` with `align_corners`) collapses to its
/// lower bound.
fn reflect_coordinate(coord: f64, size: usize, align_corners: bool) -> f64 {
    let extent = size as f64;
    let low = if align_corners { 0.0 } else { -0.5 };
    let high = if align_corners { extent - 1.0 } else { extent - 0.5 };
    let span = high - low;

    if span <= 0.0 {
        // Degenerate range: every coordinate maps onto the single valid position.
        low
    } else {
        // Reflection is periodic with period 2 * span. Fold the distance from `low` into
        // one period, then mirror the second half of the period back onto the first.
        let period = span * 2.0;
        let dist = (coord - low).abs();
        let phase = dist - (dist / period).floor() * period;
        low + (span - (phase - span).abs())
    }
}