use crate::internal::*;
use crate::ops::math::round_ties_to_even;
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum InterpolationMode {
Bilinear,
Nearest,
Bicubic,
}
impl InterpolationMode {
pub fn as_str(&self) -> &'static str {
match self {
InterpolationMode::Bilinear => "bilinear",
InterpolationMode::Nearest => "nearest",
InterpolationMode::Bicubic => "bicubic",
}
}
pub fn parse(s: &str) -> TractResult<Self> {
Ok(match s {
"bilinear" => InterpolationMode::Bilinear,
"nearest" => InterpolationMode::Nearest,
"bicubic" => InterpolationMode::Bicubic,
_ => bail!("Unsupported GridSample mode: {}", s),
})
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PaddingMode {
Zeros,
Border,
Reflection,
}
impl PaddingMode {
pub fn as_str(&self) -> &'static str {
match self {
PaddingMode::Zeros => "zeros",
PaddingMode::Border => "border",
PaddingMode::Reflection => "reflection",
}
}
pub fn parse(s: &str) -> TractResult<Self> {
Ok(match s {
"zeros" => PaddingMode::Zeros,
"border" => PaddingMode::Border,
"reflection" => PaddingMode::Reflection,
_ => bail!("Unsupported GridSample padding_mode: {}", s),
})
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GridSample {
pub mode: InterpolationMode,
pub padding_mode: PaddingMode,
pub align_corners: bool,
}
impl GridSample {
fn denormalize(&self, coord: f32, size: usize) -> f32 {
if self.align_corners {
(coord + 1.0) / 2.0 * (size as f32 - 1.0)
} else {
((coord + 1.0) * size as f32 - 1.0) / 2.0
}
}
fn bounds(&self, size: usize) -> (f32, f32) {
if self.align_corners { (0.0, size as f32 - 1.0) } else { (-0.5, size as f32 - 0.5) }
}
fn pixel_at_nd(
&self,
x: &tract_ndarray::ArrayViewD<'_, f32>,
batch: usize,
channel: usize,
coords: &[isize],
spatial_sizes: &[usize],
) -> f32 {
match self.padding_mode {
PaddingMode::Zeros => {
for (&c, &s) in coords.iter().zip(spatial_sizes.iter()) {
if c < 0 || c >= s as isize {
return 0.0;
}
}
let mut idx = vec![batch, channel];
idx.extend(coords.iter().map(|&c| c as usize));
x[idx.as_slice()]
}
PaddingMode::Border => {
let mut idx = vec![batch, channel];
for (&c, &s) in coords.iter().zip(spatial_sizes.iter()) {
idx.push((c.max(0) as usize).min(s - 1));
}
x[idx.as_slice()]
}
PaddingMode::Reflection => {
let mut idx = vec![batch, channel];
for (&c, &s) in coords.iter().zip(spatial_sizes.iter()) {
let (lo, hi) = self.bounds(s);
idx.push(gs_reflect(c as f32, lo, hi) as usize);
}
x[idx.as_slice()]
}
}
}
fn apply_padding(&self, coord: f32, lo: f32, hi: f32) -> f32 {
match self.padding_mode {
PaddingMode::Border => coord.clamp(0.0, hi + lo),
PaddingMode::Reflection => gs_reflect(coord, lo, hi),
PaddingMode::Zeros => coord,
}
}
fn is_oob(&self, coords: &[f32], bounds: &[(f32, f32)]) -> bool {
coords.iter().zip(bounds.iter()).any(|(&c, &(lo, hi))| c < lo || c > hi)
}
fn pad_coords(&self, coords: &mut [f32], bounds: &[(f32, f32)]) {
for (c, &(lo, hi)) in coords.iter_mut().zip(bounds.iter()) {
*c = self.apply_padding(*c, lo, hi);
}
}
fn sample_nd(
&self,
x: &tract_ndarray::ArrayViewD<'_, f32>,
batch: usize,
channel: usize,
pixel_coords: &[f32],
spatial_sizes: &[usize],
) -> f32 {
let ndim = pixel_coords.len();
let bounds: Vec<(f32, f32)> = spatial_sizes.iter().map(|&s| self.bounds(s)).collect();
match self.mode {
InterpolationMode::Nearest => {
let mut coords: Vec<f32> =
pixel_coords.iter().map(|&c| round_ties_to_even(c)).collect();
if self.is_oob(&coords, &bounds) {
self.pad_coords(&mut coords, &bounds);
}
let icoords: Vec<isize> = coords.iter().map(|&c| c as isize).collect();
self.pixel_at_nd(x, batch, channel, &icoords, spatial_sizes)
}
InterpolationMode::Bilinear => {
let mut coords: Vec<f32> = pixel_coords.to_vec();
if self.is_oob(&coords, &bounds) {
self.pad_coords(&mut coords, &bounds);
}
let num_corners = 1 << ndim;
let mut result = 0.0f32;
for corner in 0..num_corners {
let mut weight = 1.0f32;
let mut icoords = Vec::with_capacity(ndim);
for (d, &c) in coords.iter().enumerate() {
let lo = c.floor() as isize;
if (corner >> d) & 1 == 0 {
icoords.push(lo);
weight *= (lo + 1) as f32 - c;
} else {
icoords.push(lo + 1);
weight *= c - lo as f32;
}
}
result += weight * self.pixel_at_nd(x, batch, channel, &icoords, spatial_sizes);
}
result
}
InterpolationMode::Bicubic => {
assert!(ndim == 2, "Bicubic interpolation only supports 2D spatial dimensions");
let (mut px, mut py) = (pixel_coords[0], pixel_coords[1]);
if self.is_oob(&[px, py], &bounds) {
px = self.apply_padding(px, bounds[0].0, bounds[0].1);
py = self.apply_padding(py, bounds[1].0, bounds[1].1);
}
let x0 = px.floor() as isize - 1;
let y0 = py.floor() as isize - 1;
let dx = px - x0 as f32 - 1.0;
let dy = py - y0 as f32 - 1.0;
let mut p = [[0.0f32; 4]; 4];
for (h, row) in p.iter_mut().enumerate() {
for (w, val) in row.iter_mut().enumerate() {
*val = self.pixel_at_nd(
x,
batch,
channel,
&[x0 + w as isize, y0 + h as isize],
spatial_sizes,
);
}
}
bicubic_interpolate(&p, dx, dy)
}
}
}
}
fn gs_reflect(x: f32, x_min: f32, x_max: f32) -> f32 {
let rng = x_max - x_min;
if rng == 0.0 {
return x_min;
}
if x < x_min {
let dx = x_min - x;
let n = (dx / rng) as i32;
let r = dx - n as f32 * rng;
if n % 2 == 0 { x_min + r } else { x_max - r }
} else if x > x_max {
let dx = x - x_max;
let n = (dx / rng) as i32;
let r = dx - n as f32 * rng;
if n % 2 == 0 { x_max - r } else { x_min + r }
} else {
x
}
}
fn bicubic_interpolate(p: &[[f32; 4]; 4], dx: f32, dy: f32) -> f32 {
let mut v = [0.0f32; 4];
let mut coeffs = [0.0f32; 4];
cubic_coeffs(dx, &mut coeffs);
for i in 0..4 {
v[i] =
coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3];
}
cubic_coeffs(dy, &mut coeffs);
coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]
}
fn cubic_coeffs(x: f32, coeffs: &mut [f32; 4]) {
let a = -0.75f32;
let xp1 = x + 1.0;
let xm1 = 1.0 - x;
let xm2 = 2.0 - x;
coeffs[0] = ((a * xp1 - 5.0 * a) * xp1 + 8.0 * a) * xp1 - 4.0 * a;
coeffs[1] = ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0;
coeffs[2] = ((a + 2.0) * xm1 - (a + 3.0)) * xm1 * xm1 + 1.0;
coeffs[3] = ((a * xm2 - 5.0 * a) * xm2 + 8.0 * a) * xm2 - 4.0 * a;
}
impl Op for GridSample {
fn name(&self) -> StaticName {
"GridSample".into()
}
op_as_typed_op!();
}
impl EvalOp for GridSample {
fn is_stateless(&self) -> bool {
true
}
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let (x, grid) = args_2!(inputs);
let input_dt = x.datum_type();
let x_tensor = x.into_tensor();
let x_cow = x_tensor.cast_to::<f32>()?;
let x = x_cow.to_plain_array_view::<f32>()?;
let grid_tensor = grid.into_tensor();
let grid_cow = grid_tensor.cast_to::<f32>()?;
let grid = grid_cow.to_plain_array_view::<f32>()?;
let x_shape = x.shape();
let grid_shape = grid.shape();
let rank = x_shape.len();
let spatial_rank = rank - 2;
let n_batch = x_shape[0];
let n_channel = x_shape[1];
let spatial_sizes: Vec<usize> = x_shape[2..].to_vec();
let mut output_shape = vec![n_batch, n_channel];
output_shape.extend_from_slice(&grid_shape[1..rank - 1]);
let output = tract_ndarray::ArrayD::from_shape_fn(&*output_shape, |idx| -> f32 {
let batch = idx[0];
let channel = idx[1];
let out_spatial: Vec<usize> = (2..rank).map(|d| idx[d]).collect();
let mut grid_idx = vec![batch];
grid_idx.extend_from_slice(&out_spatial);
grid_idx.push(0);
let mut pixel_coords = Vec::with_capacity(spatial_rank);
for (d, &size) in spatial_sizes.iter().enumerate() {
*grid_idx.last_mut().unwrap() = spatial_rank - 1 - d;
let norm_coord = grid[grid_idx.as_slice()];
pixel_coords.push(self.denormalize(norm_coord, size));
}
self.sample_nd(&x, batch, channel, &pixel_coords, &spatial_sizes)
});
Ok(tvec!(output.into_tensor().cast_to_dt(input_dt)?.into_owned().into_tvalue()))
}
}
impl TypedOp for GridSample {
as_op!();
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
let x_shape = &inputs[0].shape;
let grid_shape = &inputs[1].shape;
let rank = x_shape.len();
let mut output_shape: TVec<TDim> = tvec![x_shape[0].clone(), x_shape[1].clone()];
for d in 1..rank - 1 {
output_shape.push(grid_shape[d].clone());
}
Ok(tvec!(inputs[0].datum_type.fact(&output_shape)))
}
}