use std::mem::MaybeUninit;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensorView, NdTensorViewMut, SliceItem, Tensor, TensorView};
use crate::buffer_pool::BufferPool;
use crate::operator::{
IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList,
OutputTypesContext,
};
use crate::ops::map_value_view;
use crate::value::ValueView;
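/// Specifies how the values used to pad a tensor are generated.
///
/// `Constant` fills the padded region with a fixed value. `Reflect` mirrors
/// the values adjacent to each edge, without repeating the edge value itself.
/// `Edge` repeats the edge value and `Wrap` tiles the input periodically.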
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum PadMode {
Constant,
Reflect,
Edge,
Wrap,
}
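/// Pad a tensor by adding elements to the start and end of each dimension.
///
/// `padding` contains the number of elements to add at the start of each
/// dimension, followed by the number to add at the end of each dimension, so
/// its length must be `2 * input.ndim()` and all entries must be
/// non-negative. `const_val` is the fill value used when `mode` is
/// [`PadMode::Constant`]. The other modes derive padding values from the
/// input and only support padding the last two dimensions.
///
/// Illustrative sketch of constant padding, mirroring the tests below (not
/// compiled as a doctest; assumes `pad`, `BufferPool` and `Tensor` are in
/// scope):
///
/// ```text
/// let pool = BufferPool::new();
/// let input = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
/// // Start padding for dims 0 and 1, then end padding for dims 0 and 1.
/// let pads = &[1, 1, 1, 1];
/// let padded = pad(&pool, input.view(), &pads.into(), PadMode::Constant, 0.0).unwrap();
/// // `padded` is a 4x4 tensor with the input in the center and zeros around it.
/// assert_eq!(padded.shape(), &[4, 4]);
/// ```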
pub fn pad<T: Copy + Default + PartialEq>(
pool: &BufferPool,
input: TensorView<T>,
padding: &NdTensorView<i32, 1>,
mode: PadMode,
const_val: T,
) -> Result<Tensor<T>, OpError> {
if padding.size(0) != input.ndim() * 2 {
return Err(OpError::InvalidValue(
"padding length should be 2 * input dims",
));
}
if !padding.iter().all(|x| *x >= 0) {
return Err(OpError::InvalidValue("Pad only supports positive pads"));
}
let out_shape: Vec<_> = input
.shape()
.iter()
.enumerate()
.map(|(i, size)| {
let start_pad = padding[i] as usize;
let end_pad = padding[[input.ndim() + i]] as usize;
start_pad + size + end_pad
})
.collect();
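    // If no padding was requested, just copy the input.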
if out_shape == input.shape() {
return Ok(input.to_tensor_in(pool));
}
let output = match mode {
PadMode::Constant => {
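            // Region of the output that the original input is copied into.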
let non_pad_region: Vec<SliceItem> = input
.shape()
.iter()
.enumerate()
.map(|(i, size)| {
let start_pad = padding[i] as usize;
(start_pad..start_pad + size).into()
})
.collect();
let mut output = if const_val == T::default() {
Tensor::zeros_in(pool, &out_shape)
} else {
Tensor::full_in(pool, &out_shape, const_val)
};
output
.slice_mut(non_pad_region.as_slice())
.copy_from(&input);
output
}
PadMode::Reflect | PadMode::Edge | PadMode::Wrap => {
const PAD_DIMS: usize = 2;
let batch_dims = input.ndim().saturating_sub(PAD_DIMS);
if out_shape[..batch_dims] != input.shape()[..batch_dims] {
return Err(OpError::UnsupportedValue(
"Pad only supports non-constant padding of last 2 dims",
));
}
if input.shape()[batch_dims..].contains(&0) {
return Err(OpError::InvalidValue(
"Padded dimension for non-constant padding is empty",
));
}
let pad_dims = input.ndim() - batch_dims;
let (pad_top, pad_left) = if pad_dims == 1 {
(0, padding[[batch_dims]] as usize)
} else {
(
padding[[batch_dims]] as usize,
padding[[batch_dims + 1]] as usize,
)
};
let mut input = input.view();
let mut output = Tensor::uninit_in(pool, &out_shape);
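            // Insert leading axes if needed so that both views have at least
            // `PAD_DIMS` dims to iterate over below.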
while input.ndim() < PAD_DIMS {
input.insert_axis(0);
output.insert_axis(0);
}
for (out_img, in_img) in output
.inner_iter_mut::<PAD_DIMS>()
.zip(input.inner_iter::<PAD_DIMS>())
{
match mode {
PadMode::Reflect => {
fill_pad(out_img, in_img, pad_top, pad_left, ReflectPad);
}
PadMode::Edge => {
fill_pad(out_img, in_img, pad_top, pad_left, EdgePad);
}
PadMode::Wrap => {
fill_pad(out_img, in_img, pad_top, pad_left, WrapPad);
}
PadMode::Constant => unreachable!(),
}
}
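            // Remove any leading axes that were inserted above so the output
            // matches the expected shape.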
while output.ndim() > out_shape.len() {
output.remove_axis(0);
}
unsafe { output.assume_init() }
}
};
Ok(output)
}
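/// Fill a 2D output view with values copied from `src`.
///
/// `pad_top` and `pad_left` give the amount of padding added before the
/// source rows and columns respectively. `src_index` maps each output row or
/// column index back to a source index according to the padding mode.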
fn fill_pad<T: Copy, P: PadSource>(
mut dest: NdTensorViewMut<MaybeUninit<T>, 2>,
src: NdTensorView<T, 2>,
pad_top: usize,
pad_left: usize,
src_index: P,
) {
let out_rows = dest.size(0);
let out_cols = dest.size(1);
let src_rows = src.size(0);
let src_cols = src.size(1);
for y in 0..out_rows {
let src_y = src_index.src_index(y, src_rows, pad_top);
debug_assert!(src_y < src_rows);
for x in 0..out_cols {
let src_x = src_index.src_index(x, src_cols, pad_left);
debug_assert!(src_x < src_cols);
unsafe {
dest.get_unchecked_mut([y, x])
.write(*src.get_unchecked([src_y, src_x]));
}
}
}
}
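/// Maps output indices to source indices for a non-constant padding mode.
///
/// # Safety
///
/// Implementations must return an index that is less than `len` for every
/// destination index, as [`fill_pad`] uses the result for unchecked indexing
/// into the source view.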
unsafe trait PadSource {
fn src_index(&self, dest: usize, len: usize, pad_start: usize) -> usize;
}
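/// Index mapping for [`PadMode::Reflect`]: padding mirrors the input around
/// its first and last elements, without repeating them.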
struct ReflectPad;
unsafe impl PadSource for ReflectPad {
    fn src_index(&self, dest: usize, len: usize, pad_start: usize) -> usize {
        let x = dest as isize;
        let len = len as isize;
        let pad_start = pad_start as isize;
        // Candidate source indices for the three regions of the output: the
        // padding before the input (mirrored around the first element), the
        // input itself (a direct copy) and the padding after the input
        // (mirrored around the last element).
        let src_x_start = pad_start - x;
        let src_x_mid = x - pad_start;
        let src_x_end = len - (x - len - pad_start) - 2;
        let src_x = if x < pad_start {
            src_x_start
        } else if x < len + pad_start {
            src_x_mid
        } else {
            src_x_end
        };
        // Map out-of-range indices back into `0..len`. This handles padding
        // that is larger than the input size by wrapping around.
        src_x.rem_euclid(len) as usize
    }
}
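/// Index mapping for [`PadMode::Edge`]: padding repeats the first or last
/// element of the input.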
struct EdgePad;
unsafe impl PadSource for EdgePad {
fn src_index(&self, dest: usize, len: usize, pad_start: usize) -> usize {
let len = len as isize;
let dest = dest as isize;
let src = dest - pad_start as isize;
src.clamp(0, len - 1) as usize
}
}
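/// Index mapping for [`PadMode::Wrap`]: padding tiles the input periodically.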
struct WrapPad;
unsafe impl PadSource for WrapPad {
fn src_index(&self, dest: usize, len: usize, pad_start: usize) -> usize {
let len = len as isize;
let dest = dest as isize;
let pad_start = pad_start as isize;
(dest - pad_start).rem_euclid(len) as usize
}
}
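/// Operator which pads a tensor, corresponding to the ONNX `Pad` operator.
///
/// Inputs are the tensor to pad, the pads, an optional constant fill value
/// and an optional `axes` list (not yet supported). See [`pad`].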
#[derive(Debug)]
pub struct Pad {
pub mode: PadMode,
}
impl Operator for Pad {
fn name(&self) -> &str {
"Pad"
}
fn max_inputs(&self) -> Option<usize> {
Some(4)
}
fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
let inputs = ctx.inputs();
let input = inputs.require(0)?;
let pads = inputs.require_as(1)?;
let axes: Option<NdTensorView<i32, 1>> = inputs.get_as(3)?;
if axes.is_some() {
return Err(OpError::UnsupportedValue(
"Pad operator does not yet support `axes` input",
));
}
map_value_view!(input, x, {
let const_val = inputs.get_as(2)?.unwrap_or_default();
pad(ctx.pool(), x, &pads, self.mode, const_val).into_op_result()
})
}
fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
Some([OutputType::CopyFromInput(0)].into())
}
}
#[cfg(test)]
mod tests {
use std::error::Error;
use rten_tensor::prelude::*;
use rten_tensor::test_util::expect_equal;
use rten_tensor::{NdTensor, Tensor};
use rten_testing::TestCases;
use crate::buffer_pool::BufferPool;
use crate::operator::{OpError, OperatorExt};
use crate::ops::{Pad, PadMode, pad};
use crate::value::{DataType, TryFromValueError, Value, ValueType};
fn from_slice<T: Clone>(data: &[T]) -> Tensor<T> {
Tensor::from_data(&[data.len()], data.to_vec())
}
#[test]
fn test_pad() -> Result<(), Box<dyn Error>> {
let pool = BufferPool::new();
let input = Tensor::from_data(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
let expected = Tensor::from_data(
&[4, 4],
vec![
0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0,
],
);
let const_pads = &[1, 1, 1, 1];
let result = pad(
&pool,
input.view(),
&const_pads.into(),
PadMode::Constant,
0.0,
)
.unwrap();
expect_equal(&result, &expected)?;
let zero_pads = &[0, 0, 0, 0];
let result = pad(
&pool,
input.view(),
&zero_pads.into(),
PadMode::Constant,
0.0,
)
.unwrap();
expect_equal(&result, &input)?;
let input = Tensor::from_data(&[1, 2, 2], vec![1, 2, 3, 4]);
let pads = &[0, 0, 0, 0, 1, 0];
let result = pad(&pool, input.view(), &pads.into(), PadMode::Constant, 0).unwrap();
assert_eq!(result.shape(), &[1, 3, 2]);
assert_eq!(result.data().unwrap(), &[1, 2, 3, 4, 0, 0]);
Ok(())
}
#[derive(Debug)]
struct Case {
input: Tensor,
pads: NdTensor<i32, 1>,
mode: PadMode,
expected: Result<Tensor, OpError>,
}
fn test_pad_mode(cases: &[Case]) {
cases.test_each(|case| {
let Case {
input,
pads,
mode,
expected,
} = case;
let pool = BufferPool::new();
let result = pad(&pool, input.view(), &pads.view(), *mode, 0.);
match (result, expected) {
(Ok(result), Ok(expected)) => {
expect_equal(&result, &expected).unwrap();
}
(result, expected) => assert_eq!(&result, expected),
}
});
}
#[test]
fn test_pad_constant() {
let cases = [
Case {
input: [[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]].into(),
pads: [0, 2, 0, 0].into(),
mode: PadMode::Constant,
expected: Ok(Tensor::from([
[0.0, 0.0, 1.0, 1.2],
[0.0, 0.0, 2.3, 3.4],
[0.0, 0.0, 4.5, 5.7],
])),
},
];
test_pad_mode(&cases);
}
#[test]
fn test_pad_reflect() {
let cases = [
Case {
input: [[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]].into(),
pads: [0, 2, 0, 0].into(),
mode: PadMode::Reflect,
expected: Ok(Tensor::from([
[1.0, 1.2, 1.0, 1.2],
[2.3, 3.4, 2.3, 3.4],
[4.5, 5.7, 4.5, 5.7],
])),
},
Case {
input: [[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]].into(),
pads: [0, 0, 0, 2].into(),
mode: PadMode::Reflect,
expected: Ok(Tensor::from([
[1.0, 1.2, 1.0, 1.2],
[2.3, 3.4, 2.3, 3.4],
[4.5, 5.7, 4.5, 5.7],
])),
},
Case {
input: [[1., 2., 3., 4., 5.]].into(),
pads: [0, 3, 0, 3].into(),
mode: PadMode::Reflect,
expected: Ok(Tensor::from([[4., 3., 2., 1., 2., 3., 4., 5., 4., 3., 2.]])),
},
Case {
input: Tensor::from([1., 2., 3., 4., 5.]).into_shape([5, 1].as_slice()),
pads: [3, 0, 3, 0].into(),
mode: PadMode::Reflect,
expected: Ok(Tensor::from([4., 3., 2., 1., 2., 3., 4., 5., 4., 3., 2.])
.into_shape([5 + 2 * 3, 1].as_slice())),
},
Case {
input: [1., 2., 3., 4.].into(),
pads: [2, 2].into(),
mode: PadMode::Reflect,
expected: Ok(Tensor::from([3., 2., 1., 2., 3., 4., 3., 2.])),
},
Case {
input: Tensor::from(2.),
pads: NdTensor::from([]),
mode: PadMode::Reflect,
expected: Ok(Tensor::from(2.)),
},
Case {
input: [[[1., 2., 3.]]].into(),
pads: [0, 0, 2, 0, 0, 0].into(),
mode: PadMode::Reflect,
expected: Ok(Tensor::from([[[3., 2., 1., 2., 3.]]])),
},
Case {
input: [[[1., 2., 3.]]].into(),
pads: [0, 0, 0, 0, 0, 2].into(),
mode: PadMode::Reflect,
expected: Ok(Tensor::from([[[1., 2., 3., 2., 1.]]])),
},
Case {
input: [[[1.], [2.], [3.]]].into(),
pads: [0, 2, 0, 0, 0, 0].into(),
mode: PadMode::Reflect,
expected: Ok(Tensor::from([[[3.], [2.], [1.], [2.], [3.]]])),
},
Case {
input: [[[1., 2., 3.]]].into(),
pads: [0, 0, 0, 2, 0, 0].into(),
mode: PadMode::Reflect,
expected: Err(OpError::UnsupportedValue(
"Pad only supports non-constant padding of last 2 dims",
)),
},
Case {
input: Tensor::zeros(&[3, 0]),
pads: NdTensor::from([0, 2, 0, 0]),
mode: PadMode::Reflect,
expected: Err(OpError::InvalidValue(
"Padded dimension for non-constant padding is empty",
)),
},
];
test_pad_mode(&cases);
}
#[test]
fn test_pad_edge() {
let cases = [
Case {
input: [[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]].into(),
pads: [0, 2, 0, 0].into(),
mode: PadMode::Edge,
expected: Ok(Tensor::from([
[1.0, 1.0, 1.0, 1.2],
[2.3, 2.3, 2.3, 3.4],
[4.5, 4.5, 4.5, 5.7],
])),
},
];
test_pad_mode(&cases);
}
#[test]
fn test_pad_wrap() {
let cases = [
Case {
input: [[1.0, 1.2], [2.3, 3.4], [4.5, 5.7]].into(),
pads: [2, 1, 1, 1].into(),
mode: PadMode::Wrap,
expected: Ok(Tensor::from([
[3.4, 2.3, 3.4, 2.3],
[5.7, 4.5, 5.7, 4.5],
[1.2, 1.0, 1.2, 1.0],
[3.4, 2.3, 3.4, 2.3],
[5.7, 4.5, 5.7, 4.5],
[1.2, 1.0, 1.2, 1.0],
])),
},
];
test_pad_mode(&cases);
}
#[test]
fn test_pad_op() -> Result<(), Box<dyn Error>> {
let input = Tensor::from_data(&[2, 2], vec![1.0, 2.0, 3.0, 4.0]);
let pads = from_slice(&[1, 1, 1, 1]);
let expected = Tensor::from_data(
&[4, 4],
vec![
0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0,
],
);
let op = Pad {
mode: PadMode::Constant,
};
let result: Tensor<f32> = op.run_simple((&input, &pads)).unwrap();
expect_equal(&result, &expected)?;
Ok(())
}
#[test]
fn test_pad_invalid_inputs() {
#[derive(Debug)]
struct Case {
input: Tensor<f32>,
pads: Tensor<i32>,
const_val: Option<Value>,
expected_error: OpError,
}
let input = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
let op = Pad {
mode: PadMode::Constant,
};
let cases = [
Case {
input: input.clone(),
pads: from_slice(&[1]),
const_val: None,
expected_error: OpError::InvalidValue("padding length should be 2 * input dims"),
},
Case {
input: input.clone(),
pads: from_slice(&[1, 1, 1, -1]),
const_val: None,
expected_error: OpError::InvalidValue("Pad only supports positive pads"),
},
Case {
input: input.clone(),
pads: from_slice(&[1, 1, 1, -1]),
const_val: Some(Tensor::from(1).into()),
expected_error: OpError::InputCastFailed {
index: 2,
error: TryFromValueError::WrongType {
actual: ValueType::Tensor(DataType::Int32),
expected: ValueType::Tensor(DataType::Float),
},
},
},
Case {
input: input.clone(),
pads: from_slice(&[1, 1, 1, -1]),
const_val: Some(from_slice(&[1.0, 2.0]).into()),
expected_error: OpError::InputCastFailed {
index: 2,
error: TryFromValueError::WrongRank {
actual: 1,
expected: 0,
},
},
},
];
cases.test_each(|case| {
let Case {
input,
pads,
const_val,
expected_error,
} = case;
let result = if let Some(const_val) = const_val {
op.run_simple::<_, Tensor<f32>>((input, pads, const_val))
} else {
op.run_simple::<_, Tensor<f32>>((input, pads))
};
assert_eq!(result.err().as_ref(), Some(expected_error));
});
}
}