Skip to main content

burn_ndarray/ops/
grid_sample.rs

1use burn_backend::ElementConversion;
2use burn_backend::ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode};
3#[cfg(not(feature = "std"))]
4#[allow(unused_imports)]
5use num_traits::Float;
6
7use ndarray::Array4;
8
9use crate::SharedArray;
10use crate::{FloatNdArrayElement, UnsafeSharedRef, iter_range_par, run_par};
11
12/// Sample a tensor using grid-based sampling.
13///
14/// # Arguments
15///
16/// * `tensor` - The tensor being sampled from, must be contiguous with shape (N, C, H_in, W_in)
17/// * `grid` - A tensor of locations, with shape (N, H_out, W_out, 2). Values are [-1, 1].
18///   A [x = -1, y = -1] means top-left, and [x = 1, y = 1] means bottom-right
19/// * `options` - Grid sampling options (mode, padding_mode, align_corners)
20///
21/// # Returns
22///
23/// A tensor with shape (N, C, H_out, W_out)
24pub(crate) fn grid_sample_2d<E: FloatNdArrayElement>(
25    tensor: SharedArray<E>,
26    grid: SharedArray<E>,
27    options: GridSampleOptions,
28) -> SharedArray<E> {
29    match options.mode {
30        InterpolateMode::Bilinear => (),
31        _ => todo!(
32            "grid_sample_2d with {:?} mode is not implemented",
33            options.mode
34        ),
35    }
36
37    let tensor = tensor.into_dimensionality::<ndarray::Ix4>().unwrap();
38    let grid = grid.into_dimensionality::<ndarray::Ix4>().unwrap();
39
40    let (batch_size, channels, height_in, width_in) = tensor.dim();
41    let (b, height_out, width_out, d) = grid.dim();
42    assert!(batch_size == b);
43    assert!(2 == d);
44
45    let mut output = Array4::zeros((batch_size, channels, height_out, width_out));
46    let unsafe_shared_out = UnsafeSharedRef::new(&mut output);
47
48    let sample_count = batch_size * channels * height_out * width_out;
49    let strides = (
50        channels * height_out * width_out,
51        height_out * width_out,
52        width_out,
53    );
54
55    let align = options.align_corners;
56    let pad_mode = options.padding_mode;
57
58    run_par!(|| {
59        iter_range_par!(0, sample_count).for_each(|id| {
60            let (b, c, y, x) = (
61                id / strides.0,
62                id % strides.0 / strides.1,
63                id % strides.1 / strides.2,
64                id % strides.2,
65            );
66
67            let sample_x = grid[(b, y, x, 0)].elem::<f64>();
68            let sample_y = grid[(b, y, x, 1)].elem::<f64>();
69
70            // Convert normalized grid coordinates [-1, 1] to pixel coordinates
71            let (px, py) = if align {
72                // align_corners=true: x_pixel = (x_norm + 1) * (width - 1) / 2
73                // Maps -1 to 0 and 1 to width - 1
74                let px = (sample_x + 1.0) * ((width_in - 1) as f64) / 2.0;
75                let py = (sample_y + 1.0) * ((height_in - 1) as f64) / 2.0;
76                (px, py)
77            } else {
78                // align_corners=false: x_pixel = (x_norm + 1) * width / 2 - 0.5
79                // Maps -1 to -0.5 and 1 to width - 0.5
80                let px = (sample_x + 1.0) * (width_in as f64) / 2.0 - 0.5;
81                let py = (sample_y + 1.0) * (height_in as f64) / 2.0 - 0.5;
82                (px, py)
83            };
84
85            // Bilinear interpolation with the specified padding mode
86            let val =
87                bilinear_interpolate(&tensor, b, c, px, py, width_in, height_in, pad_mode, align);
88
89            unsafe {
90                let output = unsafe_shared_out.get();
91                output[(b, c, y, x)] = val.elem();
92            }
93        });
94    });
95
96    output.into_dyn().into_shared()
97}
98
99/// Bilinear interpolation at a point with configurable padding mode.
100#[allow(clippy::too_many_arguments)]
101fn bilinear_interpolate<E, S>(
102    source: &ndarray::ArrayBase<S, ndarray::Dim<[usize; 4]>>,
103    b: usize,
104    c: usize,
105    x: f64,
106    y: f64,
107    width: usize,
108    height: usize,
109    padding_mode: GridSamplePaddingMode,
110    align_corners: bool,
111) -> f64
112where
113    E: FloatNdArrayElement,
114    S: ndarray::Data<Elem = E>,
115{
116    // Handle inf/nan coordinates
117    if !x.is_finite() || !y.is_finite() {
118        return match padding_mode {
119            GridSamplePaddingMode::Zeros => 0.0,
120            GridSamplePaddingMode::Border => {
121                // Clamp to center of image for inf/nan
122                let cx = ((width - 1) as f64 / 2.0).clamp(0.0, (width - 1) as f64);
123                let cy = ((height - 1) as f64 / 2.0).clamp(0.0, (height - 1) as f64);
124                source[(b, c, cy as usize, cx as usize)].elem::<f64>()
125            }
126            GridSamplePaddingMode::Reflection => 0.0, // Simplified: treat as zeros for inf/nan
127        };
128    }
129
130    // Apply padding mode to get actual sampling coordinates
131    let (x, y) = match padding_mode {
132        GridSamplePaddingMode::Border => {
133            // Clamp coordinates to valid range [0, size-1]
134            let x = x.clamp(0.0, (width - 1) as f64);
135            let y = y.clamp(0.0, (height - 1) as f64);
136            (x, y)
137        }
138        GridSamplePaddingMode::Reflection => {
139            // Reflect coordinates at boundaries
140            let x = reflect_coordinate(x, width, align_corners);
141            let y = reflect_coordinate(y, height, align_corners);
142            (x, y)
143        }
144        GridSamplePaddingMode::Zeros => (x, y), // Keep as-is, handle out-of-bounds in read
145    };
146
147    // Get the four corner indices
148    let x0 = x.floor() as i64;
149    let y0 = y.floor() as i64;
150    let x1 = x0.saturating_add(1);
151    let y1 = y0.saturating_add(1);
152
153    // Compute interpolation weights (fractional part)
154    let x_frac = x - x.floor();
155    let y_frac = y - y.floor();
156
157    // Helper to read a value based on padding mode
158    let read_value = |xi: i64, yi: i64| -> f64 {
159        match padding_mode {
160            GridSamplePaddingMode::Zeros => {
161                // Return 0 for out-of-bounds
162                if xi >= 0 && xi < width as i64 && yi >= 0 && yi < height as i64 {
163                    source[(b, c, yi as usize, xi as usize)].elem::<f64>()
164                } else {
165                    0.0
166                }
167            }
168            GridSamplePaddingMode::Border | GridSamplePaddingMode::Reflection => {
169                // Coordinates should already be in valid range after clamping/reflection
170                let xi = xi.clamp(0, (width - 1) as i64) as usize;
171                let yi = yi.clamp(0, (height - 1) as i64) as usize;
172                source[(b, c, yi, xi)].elem::<f64>()
173            }
174        }
175    };
176
177    // Read the four corners
178    let v00 = read_value(x0, y0);
179    let v01 = read_value(x0, y1);
180    let v10 = read_value(x1, y0);
181    let v11 = read_value(x1, y1);
182
183    // Bilinear interpolation weights
184    let w00 = (1.0 - x_frac) * (1.0 - y_frac);
185    let w01 = (1.0 - x_frac) * y_frac;
186    let w10 = x_frac * (1.0 - y_frac);
187    let w11 = x_frac * y_frac;
188
189    v00 * w00 + v01 * w01 + v10 * w10 + v11 * w11
190}
191
/// Folds a coordinate back into the valid sampling range by mirroring it at both ends
/// (a triangle wave over the range).
///
/// The range is `[0, size - 1]` when `align_corners` is true and `[-0.5, size - 0.5]`
/// otherwise. When the range has no positive width, its lower bound is returned.
fn reflect_coordinate(coord: f64, size: usize, align_corners: bool) -> f64 {
    let extent = size as f64;
    let lower = if align_corners { 0.0 } else { -0.5 };
    let upper = if align_corners {
        extent - 1.0
    } else {
        extent - 0.5
    };

    let span = upper - lower;
    // Degenerate range: there is nothing to bounce between.
    if span <= 0.0 {
        return lower;
    }

    // Measure the distance from the lower bound (mirroring negatives), wrap it into one
    // full back-and-forth period, then fold the second half of the period back down.
    let period = 2.0 * span;
    let distance = (coord - lower).abs();
    let wrapped = distance - (distance / period).floor() * period;
    span - (wrapped - span).abs() + lower
}