use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicUsize, Ordering};
use rayon::prelude::*;
use rten_shape_inference::ops as shape_ops;
use rten_simd::SimdOp;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor, TensorView, TensorViewMut};
use smallvec::SmallVec;
use crate::buffer_pool::BufferPool;
use crate::infer_shapes::{InferShapes, impl_infer_shapes};
use crate::operator::{
IntoOpResult, OpError, OpRunContext, Operator, OutputList, OutputType, OutputTypeList,
OutputTypesContext, static_dims,
};
use crate::ops::{Padding, check_value};
/// Rounding mode used when computing the output size of a pooling operation.
///
/// The `AveragePool` and `MaxPool` operators map their `ceil_mode` flag to
/// `Ceil` vs `Floor`.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub enum RoundMode {
    /// Round the output size down (the default).
    #[default]
    Floor,
    /// Round the output size up.
    Ceil,
}
/// Padding specification for a single spatial axis.
#[derive(Copy, Clone)]
enum AxisPadding {
    /// Choose padding so the output size is `ceil(input_size / stride)`.
    Same,
    /// Explicit number of padding elements at the start and end of the axis.
    Fixed { start: usize, end: usize },
}
impl AxisPadding {
    /// Split a 2D `Padding` spec into per-axis padding as `[height, width]`.
    ///
    /// Fixed padding must contain exactly four values ordered as
    /// `[top, left, bottom, right]`, otherwise an `InvalidValue` error is
    /// returned.
    fn from_2d(pad: Padding) -> Result<[AxisPadding; 2], OpError> {
        match pad {
            Padding::Same => Ok([AxisPadding::Same; 2]),
            Padding::Fixed(pads) => match *pads.as_slice() {
                [top, left, bottom, right] => Ok([
                    AxisPadding::Fixed {
                        start: top,
                        end: bottom,
                    },
                    AxisPadding::Fixed {
                        start: left,
                        end: right,
                    },
                ]),
                _ => Err(OpError::InvalidValue("Expected 4 padding values")),
            },
        }
    }
}
/// Compute the output size and resolved (start, end) padding for one spatial
/// axis of a pooling operation.
///
/// Returns `(out_size, pad_start, pad_end)` or an error if the kernel,
/// stride or dilation values are invalid, or if the padded input is smaller
/// than the dilated kernel.
fn output_size_and_padding_for_axis(
    in_size: usize,
    kernel_size: usize,
    stride: usize,
    padding: AxisPadding,
    dilation: usize,
    round_mode: RoundMode,
) -> Result<(usize, usize, usize), OpError> {
    check_value!(dilation > 0, InvalidValue, "Dilations must be > 0");
    check_value!(kernel_size > 0, InvalidValue, "Kernel size must be > 0");
    check_value!(stride > 0, InvalidValue, "Strides must be > 0");
    match padding {
        AxisPadding::Same => {
            // "Same" padding: the output size is `ceil(in/stride)` regardless
            // of the round mode, and the padding is whatever is needed to
            // achieve that, split between start and end with the extra
            // element (if the total is odd) at the end.
            let out_size = in_size.div_ceil(stride);
            let pad_total = ((out_size - 1) * stride + (kernel_size - 1) * dilation + 1)
                .saturating_sub(in_size);
            let pad_start = pad_total / 2;
            let pad_end = pad_total.div_ceil(2);
            Ok((out_size, pad_start, pad_end))
        }
        AxisPadding::Fixed {
            start: pad_start,
            end: pad_end,
        } => {
            let padded_in_size = in_size + pad_start + pad_end;
            let dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1);
            if padded_in_size < dilated_kernel_size {
                return Err(OpError::InvalidValue("Input too small for kernel size"));
            }
            // Number of window positions along this axis:
            // `(padded - ((k-1)*dilation + 1)) / stride + 1`, rounded down or
            // up per `round_mode`. `div_ceil` already rounds up, so no manual
            // `+ stride - 1` adjustment is applied here — adding it on top of
            // `div_ceil` would over-count by one whenever the numerator is an
            // exact multiple of the stride.
            let numerator = padded_in_size - dilation * (kernel_size - 1) - 1;
            let mut out_size = match round_mode {
                RoundMode::Floor => numerator / stride + 1,
                RoundMode::Ceil => numerator.div_ceil(stride) + 1,
            };
            // In ceil mode, the last window must start inside the input or
            // the start padding, not entirely within the end padding.
            if round_mode == RoundMode::Ceil && (out_size - 1) * stride >= in_size + pad_start {
                out_size -= 1;
            }
            Ok((out_size, pad_start, pad_end))
        }
    }
}
/// Compute the 2D output size and resolved padding for a pooling operation.
///
/// Sizes and strides are `(height, width)` tuples. `dilations` defaults to
/// `(1, 1)` if `None`. On success, returns
/// `(out_h, out_w, [pad_top, pad_left, pad_bottom, pad_right])`.
pub fn calc_output_size_and_padding(
    in_size: (usize, usize),
    kernel_size: (usize, usize),
    strides: (usize, usize),
    padding: Padding,
    dilations: Option<(usize, usize)>,
    round_mode: RoundMode,
) -> Result<(usize, usize, [usize; 4]), OpError> {
    let dilations = dilations.unwrap_or((1, 1));
    let [pad_h, pad_w] = AxisPadding::from_2d(padding)?;
    // Resolve each axis independently.
    let (out_h, pad_top, pad_bottom) = output_size_and_padding_for_axis(
        in_size.0,
        kernel_size.0,
        strides.0,
        pad_h,
        dilations.0,
        round_mode,
    )?;
    let (out_w, pad_left, pad_right) = output_size_and_padding_for_axis(
        in_size.1,
        kernel_size.1,
        strides.1,
        pad_w,
        dilations.1,
        round_mode,
    )?;
    Ok((out_h, out_w, [pad_top, pad_left, pad_bottom, pad_right]))
}
/// Number of channels processed together per pass in `pool_impl`. The grouped
/// call site hard-codes four channel indices, so this must remain 4.
const CHAN_GROUP_SIZE: usize = 4;
/// Shared implementation of the max and average pooling operators.
///
/// Window elements are combined with `fold`, starting from `fold_init`. The
/// folded value is then passed to `average` along with the number of
/// non-padding elements in the window to produce each output element
/// (average pooling divides here; max pooling passes the value through).
///
/// Inputs with one spatial dim (NCW) are handled by inserting a size-1
/// height axis and recursing into the NCHW path.
fn pool_impl<T: Copy + Send, F: Fn(T, T) -> T + Sync, A: Fn(T, usize) -> T + Sync>(
    pool: &BufferPool,
    input: TensorView<T>,
    kernel_size: &[usize],
    strides: &[usize],
    padding: Padding,
    fold_init: T,
    fold: &F,
    average: &A,
    round_mode: RoundMode,
) -> Result<Tensor<T>, OpError>
where
    for<'a> TensorViewMut<'a, T>: Send,
    for<'a> TensorView<'a, T>: Send,
    for<'a> &'a T: Sync,
{
    // Kernel and stride specs must have one entry per spatial dim.
    let spatial_dims = input.ndim().saturating_sub(2);
    if kernel_size.len() != spatial_dims {
        return Err(OpError::InvalidValue(
            "kernel_size len does not match spatial dims",
        ));
    }
    if strides.len() != spatial_dims {
        return Err(OpError::InvalidValue(
            "strides len does not match spatial dims",
        ));
    }
    match spatial_dims {
        1 => {
            // Reduce the 1D case to 2D by inserting a size-1 height axis,
            // recursing, then removing the axis from the result.
            let mut input_2d = input.view();
            input_2d.insert_axis(2);
            let padding_2d = padding.expand_1d_to_2d()?;
            let mut result_2d = pool_impl(
                pool,
                input_2d,
                &[1, kernel_size[0]],
                &[1, strides[0]],
                padding_2d,
                fold_init,
                fold,
                average,
                round_mode,
            )?;
            result_2d.remove_axis(2);
            return Ok(result_2d);
        }
        2 => {}
        _ => {
            return Err(OpError::UnsupportedValue(
                "Only inputs with 1 or 2 spatial dims are supported",
            ));
        }
    }
    // From here on the input is 4D NCHW with a 2D kernel.
    let kernel_size: [usize; 2] = kernel_size.try_into().unwrap();
    let strides: [usize; 2] = strides.try_into().unwrap();
    let input = static_dims!(input, 4, "NCHW")?;
    let [batch, in_c, in_h, in_w] = input.shape();
    let (out_h, out_w, fixed_padding) = calc_output_size_and_padding(
        (in_h, in_w),
        (kernel_size[0], kernel_size[1]),
        (strides[0], strides[1]),
        padding,
        None,
        round_mode,
    )?;
    let [pad_top, pad_left, _pad_bottom, _pad_right] = fixed_padding;
    // Output is allocated uninitialized; every element must be written below
    // before `assume_init`.
    let mut output = NdTensor::uninit_in(pool, [batch, in_c, out_h, out_w]);
    /// Pool the channels listed in `chans` for a single batch item, so the
    /// per-window bounds bookkeeping is shared across the group.
    fn pool_chans<T: Copy, F: Fn(T, T) -> T, A: Fn(T, usize) -> T, const N: usize>(
        mut out: NdTensorViewMut<MaybeUninit<T>, 3>,
        in_view: NdTensorView<T, 3>,
        chans: [usize; N],
        [kernel_h, kernel_w]: [usize; 2],
        [stride_h, stride_w]: [usize; 2],
        [pad_top, pad_left]: [usize; 2],
        fold_init: T,
        fold: F,
        average: A,
    ) {
        let [out_chans, out_h, out_w] = out.shape();
        let [in_chans, in_h, in_w] = in_view.shape();
        // Validate channel indices up front so the unchecked accesses below
        // only need in-bounds spatial coordinates.
        assert!(chans.into_iter().all(|c| c < out_chans && c < in_chans));
        for out_y in 0..out_h {
            let min_in_y = out_y * stride_h;
            let max_in_y = min_in_y + kernel_h.saturating_sub(1);
            // True if the window along Y lies entirely within the (unpadded)
            // input.
            let y_non_pad_region = min_in_y >= pad_top && max_in_y < in_h + pad_top;
            for out_x in 0..out_w {
                let min_in_x = out_x * stride_w;
                let max_in_x = min_in_x + kernel_w.saturating_sub(1);
                let x_non_pad_region = min_in_x >= pad_left && max_in_x < in_w + pad_left;
                let mut accumulator = [fold_init; N];
                let mut non_pad_elements = 0;
                if y_non_pad_region && x_non_pad_region {
                    // Fast path: the whole window lies inside the input, so
                    // no per-element bounds tests are needed.
                    non_pad_elements = kernel_h * kernel_w;
                    for k_y in 0..kernel_h {
                        for k_x in 0..kernel_w {
                            let in_y = out_y * stride_h + k_y;
                            let in_x = out_x * stride_w + k_x;
                            for (i, chan) in chans.into_iter().enumerate() {
                                // SAFETY: `chan` was validated by the assert
                                // above; the fast-path conditions guarantee
                                // the spatial indices are in bounds.
                                let val = unsafe {
                                    *in_view.get_unchecked([chan, in_y - pad_top, in_x - pad_left])
                                };
                                accumulator[i] = fold(accumulator[i], val);
                            }
                        }
                    }
                } else {
                    // Slow path: the window overlaps padding. Fold and count
                    // only the elements that fall inside the input.
                    for k_y in 0..kernel_h {
                        for k_x in 0..kernel_w {
                            let in_y = out_y * stride_h + k_y;
                            let in_x = out_x * stride_w + k_x;
                            if in_y >= pad_top
                                && in_y < in_h + pad_top
                                && in_x >= pad_left
                                && in_x < in_w + pad_left
                            {
                                for (i, chan) in chans.into_iter().enumerate() {
                                    // SAFETY: `chan` was validated by the
                                    // assert above; the `if` just checked the
                                    // spatial indices are in bounds.
                                    let val = unsafe {
                                        *in_view.get_unchecked([
                                            chan,
                                            in_y - pad_top,
                                            in_x - pad_left,
                                        ])
                                    };
                                    accumulator[i] = fold(accumulator[i], val);
                                }
                                non_pad_elements += 1;
                            }
                        }
                    }
                }
                for (i, chan) in chans.into_iter().enumerate() {
                    // SAFETY: `chan < out_chans` was asserted above and
                    // `out_y`/`out_x` are bounded by the output shape.
                    unsafe {
                        out.get_unchecked_mut([chan, out_y, out_x])
                            .write(average(accumulator[i], non_pad_elements));
                    }
                }
            }
        }
    }
    let accum_init_val = || fold_init;
    // Count of output elements written, used to check that the whole output
    // buffer is initialized before `assume_init`.
    let n_init = AtomicUsize::new(0);
    // Parallelize over the batch dimension.
    output
        .axis_iter_mut(0)
        .into_par_iter()
        .zip(input.axis_iter(0))
        .for_each(|(mut out_item, in_item)| {
            let [_, out_h, out_w] = out_item.shape();
            const N: usize = CHAN_GROUP_SIZE;
            // Process channels in groups of `N` (indices hard-coded below).
            for chan in (0..in_c).step_by(N) {
                if in_c - chan < N {
                    break;
                }
                pool_chans(
                    out_item.view_mut(),
                    in_item,
                    [chan, chan + 1, chan + 2, chan + 3],
                    kernel_size,
                    strides,
                    [pad_top, pad_left],
                    accum_init_val(),
                    fold,
                    average,
                );
                n_init.fetch_add(N * out_h * out_w, Ordering::SeqCst);
            }
            // Remaining channels (fewer than a full group) are processed one
            // at a time.
            for chan in (in_c - in_c % N)..in_c {
                pool_chans(
                    out_item.view_mut(),
                    in_item,
                    [chan],
                    kernel_size,
                    strides,
                    [pad_top, pad_left],
                    accum_init_val(),
                    fold,
                    average,
                );
                n_init.fetch_add(out_h * out_w, Ordering::SeqCst);
            }
        });
    assert!(n_init.load(Ordering::SeqCst) == output.len());
    // SAFETY: the assert above confirms every output element was written.
    let output = unsafe { output.assume_init() };
    Ok(output.into())
}
/// Perform average pooling over the spatial dims of an NCHW (or NCW) tensor.
///
/// When `count_include_pad` is true, averages divide by the full window
/// size; otherwise they divide by the number of non-padding elements in the
/// window.
pub fn average_pool(
    pool: &BufferPool,
    input: TensorView,
    kernel_size: &[usize],
    strides: &[usize],
    padding: Padding,
    count_include_pad: bool,
    round_mode: RoundMode,
) -> Result<Tensor, OpError> {
    // Full window size, used as the divisor when padding counts towards the
    // average.
    let window_len = kernel_size.iter().product::<usize>() as f32;
    let sum = |acc: f32, x: f32| acc + x;
    let divide = move |acc: f32, non_pad_elements: usize| {
        let divisor = if count_include_pad {
            window_len
        } else {
            non_pad_elements as f32
        };
        acc / divisor
    };
    pool_impl(
        pool,
        input,
        kernel_size,
        strides,
        padding,
        0.,
        &sum,
        &divide,
        round_mode,
    )
}
/// Average pooling operator.
#[derive(Debug)]
pub struct AveragePool {
    /// Pooling window size per spatial axis.
    pub kernel_size: SmallVec<[usize; 2]>,
    /// Padding applied to the spatial axes.
    pub padding: Padding,
    /// If true, divide sums by the full window size rather than the number
    /// of non-padding elements.
    pub count_include_pad: bool,
    /// Window step per spatial axis.
    pub strides: SmallVec<[usize; 2]>,
    /// If true, round the output size up instead of down.
    pub ceil_mode: bool,
}
impl Operator for AveragePool {
    fn name(&self) -> &str {
        "AveragePool"
    }
    // This operator takes exactly one input.
    fn max_inputs(&self) -> Option<usize> {
        Some(1)
    }
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let input = ctx.inputs().require_as(0)?;
        average_pool(
            ctx.pool(),
            input,
            &self.kernel_size,
            &self.strides,
            self.padding.clone(),
            self.count_include_pad,
            // `ceil_mode` selects how the output size is rounded.
            if self.ceil_mode {
                RoundMode::Ceil
            } else {
                RoundMode::Floor
            },
        )
        .into_op_result()
    }
    // Output type mirrors the first input.
    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }
    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(self)
    }
}
// Shape inference for `AveragePool` delegates to the shared pooling shape
// logic, with dilations fixed to 1.
impl_infer_shapes!(
    AveragePool,
    op,
    shape_ops::Pool {
        strides: &op.strides,
        dilations: &[1, 1],
        kernel_size: &op.kernel_size,
        padding: op.padding.as_shape_inference_padding(),
    }
);
/// Reduce each (batch, channel) slice of `input` to a single value using
/// `kernel`, keeping the leading two dims and collapsing all trailing
/// (spatial) dims to size 1.
fn global_pool<T: Clone + Send + Sync>(
    pool: &BufferPool,
    input: TensorView<T>,
    kernel: &(dyn Fn(&[T]) -> T + Send + Sync),
) -> Result<Tensor<T>, OpError> {
    if input.ndim() < 2 {
        return Err(OpError::InvalidValue("Input must have at least 2 dims"));
    }
    let batch = input.size(0);
    let chan = input.size(1);
    // Output shape is [batch, chan, 1, 1, ...] with the same rank as the
    // input.
    let mut out_shape: SmallVec<[usize; 4]> = [batch, chan].into_iter().collect();
    out_shape.resize(input.ndim(), 1);
    // Flatten the spatial dims so each (batch, channel) slice becomes one
    // lane of the reshaped tensor.
    let n_elem = input.shape().iter().skip(2).product();
    let input = input.reshaped_in(pool, [batch, chan, n_elem]);
    let n_out = batch * chan;
    let mut out_data = pool.alloc::<T>(n_out);
    let out_uninit = &mut out_data.spare_capacity_mut()[..n_out];
    // Reduce each lane in parallel, writing directly into the uninitialized
    // spare capacity of the output buffer.
    input
        .lanes(2)
        .into_par_iter()
        .zip(out_uninit.par_iter_mut())
        .for_each(|(chan_data, out)| {
            // Lanes over the innermost axis of the reshaped tensor are
            // presumably contiguous, so this unwrap should not fail.
            let chan_slice = chan_data.as_slice().unwrap();
            let reduced = kernel(chan_slice);
            out.write(reduced);
        });
    // SAFETY: the loop above wrote all `n_out` elements of `out_uninit`.
    unsafe {
        out_data.set_len(n_out);
    }
    Ok(Tensor::from_data(&out_shape, out_data))
}
/// Compute the mean of each (batch, channel) slice, collapsing all spatial
/// dims to size 1.
pub fn global_average_pool(pool: &BufferPool, input: TensorView) -> Result<Tensor, OpError> {
    let mean = |chan_data: &[f32]| -> f32 {
        let total = rten_vecmath::Sum::new(chan_data).dispatch();
        total / chan_data.len() as f32
    };
    global_pool(pool, input, &mean)
}
/// Global average pooling operator.
#[derive(Debug)]
pub struct GlobalAveragePool {}
impl Operator for GlobalAveragePool {
    fn name(&self) -> &str {
        "GlobalAveragePool"
    }
    // This operator takes exactly one input.
    fn max_inputs(&self) -> Option<usize> {
        Some(1)
    }
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let input = ctx.inputs().require_as(0)?;
        global_average_pool(ctx.pool(), input).into_op_result()
    }
    // Output type mirrors the first input.
    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }
    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&shape_ops::GlobalPool)
    }
}
/// Compute the maximum of each (batch, channel) slice, collapsing all
/// spatial dims to size 1.
pub fn global_max_pool(pool: &BufferPool, input: TensorView) -> Result<Tensor, OpError> {
    let reduce_max = |chan_data: &[f32]| rten_vecmath::MaxNum::new(chan_data).dispatch();
    global_pool(pool, input, &reduce_max)
}
/// Global max pooling operator.
#[derive(Debug)]
pub struct GlobalMaxPool {}
impl Operator for GlobalMaxPool {
    fn name(&self) -> &str {
        "GlobalMaxPool"
    }
    // This operator takes exactly one input.
    fn max_inputs(&self) -> Option<usize> {
        Some(1)
    }
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let input = ctx.inputs().require_as(0)?;
        global_max_pool(ctx.pool(), input).into_op_result()
    }
    // Output type mirrors the first input.
    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }
    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(&shape_ops::GlobalPool)
    }
}
/// Perform max pooling over the spatial dims of an NCHW (or NCW) tensor.
///
/// Padding elements never contribute to the result: the accumulator starts
/// at negative infinity and only in-bounds elements are folded in.
pub fn max_pool(
    pool: &BufferPool,
    input: TensorView,
    kernel_size: &[usize],
    strides: &[usize],
    padding: Padding,
    round_mode: RoundMode,
) -> Result<Tensor, OpError> {
    let take_max = |acc: f32, x: f32| acc.max(x);
    // Max pooling ignores the non-padding element count.
    let passthrough = |x: f32, _non_pad_count: usize| x;
    pool_impl(
        pool,
        input,
        kernel_size,
        strides,
        padding,
        f32::NEG_INFINITY,
        &take_max,
        &passthrough,
        round_mode,
    )
}
/// Max pooling operator.
#[derive(Debug)]
pub struct MaxPool {
    /// Pooling window size per spatial axis.
    pub kernel_size: SmallVec<[usize; 2]>,
    /// Padding applied to the spatial axes.
    pub padding: Padding,
    /// Window step per spatial axis.
    pub strides: SmallVec<[usize; 2]>,
    /// If true, round the output size up instead of down.
    pub ceil_mode: bool,
}
impl Operator for MaxPool {
    fn name(&self) -> &str {
        "MaxPool"
    }
    // This operator takes exactly one input.
    fn max_inputs(&self) -> Option<usize> {
        Some(1)
    }
    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let input = ctx.inputs().require_as(0)?;
        max_pool(
            ctx.pool(),
            input,
            &self.kernel_size,
            &self.strides,
            self.padding.clone(),
            // `ceil_mode` selects how the output size is rounded.
            if self.ceil_mode {
                RoundMode::Ceil
            } else {
                RoundMode::Floor
            },
        )
        .into_op_result()
    }
    // Output type mirrors the first input.
    fn output_types(&self, _ctx: &OutputTypesContext) -> Option<OutputTypeList> {
        Some([OutputType::CopyFromInput(0)].into())
    }
    fn as_infer_shapes(&self) -> Option<&dyn InferShapes> {
        Some(self)
    }
}
// Shape inference for `MaxPool` delegates to the shared pooling shape logic,
// with dilations fixed to 1.
impl_infer_shapes!(
    MaxPool,
    op,
    shape_ops::Pool {
        strides: &op.strides,
        dilations: &[1, 1],
        kernel_size: &op.kernel_size,
        padding: op.padding.as_shape_inference_padding(),
    }
);
#[cfg(test)]
mod tests {
use std::error::Error;
use rten_tensor::prelude::*;
use rten_tensor::test_util::expect_equal;
use rten_tensor::{Tensor, TensorView};
use rten_testing::TestCases;
use super::{
RoundMode, average_pool, calc_output_size_and_padding, global_average_pool,
global_max_pool, max_pool,
};
use crate::buffer_pool::BufferPool;
use crate::ops::tests::expect_eq_1e4;
use crate::ops::{OpError, Padding};
#[test]
fn test_average_pool() {
let input = Tensor::from([
[0.1, 0.2, 0.3, 0.4],
[0.5, 0.6, 0.7, 0.8],
[0.1, 0.2, 0.3, 0.4],
[0.6, 0.7, 0.8, 0.9],
])
.into_shape([1, 1, 4, 4])
.into_dyn();
let input_1d = input.slice((.., .., 0, ..));
#[derive(Debug)]
struct Case<'a> {
input: TensorView<'a>,
kernel_size: Vec<usize>,
strides: Vec<usize>,
padding: Padding,
expected: Tensor,
}
let cases = [
Case {
input: input.view(),
kernel_size: [2, 2].into(),
strides: [2, 2].into(),
padding: [0, 0, 0, 0].into(),
expected: Tensor::from_data(&[1, 1, 2, 2], vec![0.35, 0.55, 0.4, 0.6]),
},
Case {
input: input.view(),
kernel_size: [4, 4].into(),
strides: [4, 4].into(),
padding: [0, 0, 0, 0].into(),
expected: Tensor::from_data(&[1, 1, 1, 1], vec![0.475]),
},
Case {
input: input.view(),
kernel_size: [2, 4].into(),
strides: [2, 4].into(),
padding: [0, 0, 0, 0].into(),
expected: Tensor::from_data(&[1, 1, 2, 1], vec![0.45, 0.5]),
},
Case {
input: input.view(),
kernel_size: [2, 2].into(),
strides: [1, 2].into(),
padding: [0, 0, 0, 0].into(),
expected: Tensor::from_data(
&[1, 1, 3, 2],
vec![
0.35, 0.55, 0.35, 0.55, 0.4, 0.6, ],
),
},
Case {
input: input.view(),
kernel_size: [2, 2].into(),
strides: [2, 1].into(),
padding: [0, 0, 0, 0].into(),
expected: Tensor::from_data(
&[1, 1, 2, 3],
vec![
0.35, 0.45, 0.55, 0.4, 0.5, 0.6, ],
),
},
Case {
input: input_1d.view(),
kernel_size: [2].into(),
strides: [2].into(),
padding: [0, 0].into(),
expected: Tensor::from_data(&[1, 1, 2], vec![0.15, 0.35]),
},
];
cases.test_each(|case| {
let pool = BufferPool::new();
let result = average_pool(
&pool,
case.input.view(),
&case.kernel_size,
&case.strides,
case.padding.clone(),
false,
RoundMode::default(),
)
.unwrap();
expect_equal(&result, &case.expected).unwrap();
})
}
#[test]
fn test_average_pool_padding() -> Result<(), Box<dyn Error>> {
let pool = BufferPool::new();
let n_chans = super::CHAN_GROUP_SIZE + 1;
let input = Tensor::from([
[0.0809, 0.5529, 0.1534, 0.7507],
[0.4698, 0.7771, 0.9896, 0.4873],
[0.9750, 0.5160, 0.6419, 0.3670],
[0.4101, 0.3762, 0.9689, 0.4389],
]);
let [rows, cols]: [usize; 2] = input.shape().try_into().unwrap();
let input = input.broadcast([1, n_chans, rows, cols]);
let expected = Tensor::from([
[0.0809, 0.3531, 0.7507],
[0.7224, 0.7312, 0.4271],
[0.4101, 0.6725, 0.4389],
]);
let [rows, cols]: [usize; 2] = expected.shape().try_into().unwrap();
let expected = expected.broadcast([1, n_chans, rows, cols]);
let result = average_pool(
&pool,
input.as_dyn(),
&[2, 2],
&[2, 2],
[1, 1, 1, 1].into(),
false,
RoundMode::default(),
)
.unwrap();
expect_eq_1e4(&result.view(), &expected.as_dyn())?;
let expected_include_pad = Tensor::from([
[0.0202, 0.1766, 0.1877],
[0.3612, 0.7312, 0.2136],
[0.1025, 0.3363, 0.1097],
])
.broadcast([1, n_chans, 3, 3])
.to_tensor();
let result = average_pool(
&pool,
input.as_dyn(),
&[2, 2],
&[2, 2],
[1, 1, 1, 1].into(),
true,
RoundMode::default(),
)
.unwrap();
expect_eq_1e4(&result.view(), &expected_include_pad.as_dyn())?;
Ok(())
}
#[test]
fn test_global_average_pool() -> Result<(), Box<dyn Error>> {
let pool = BufferPool::new();
let input = Tensor::from_data(&[1, 2, 2, 2], vec![1., 2., 3., 4., 10., 20., 30., 40.]);
let expected = Tensor::from_data(&[1, 2, 1, 1], vec![2.5, 25.]);
let result = global_average_pool(&pool, input.view()).unwrap();
expect_equal(&result, &expected)?;
Ok(())
}
#[test]
fn test_global_max_pool() -> Result<(), Box<dyn Error>> {
let pool = BufferPool::new();
let input = Tensor::from_data(&[1, 2, 2, 2], vec![1., 2., 3., 4., 10., 20., 30., 40.]);
let expected = Tensor::from_data(&[1, 2, 1, 1], vec![4.0, 40.]);
let result = global_max_pool(&pool, input.view()).unwrap();
expect_equal(&result, &expected)?;
Ok(())
}
#[test]
fn test_global_pool_invalid_input() {
let pool = BufferPool::new();
let input = Tensor::from([1., 2., 3., 4.]);
let err = global_max_pool(&pool, input.view()).err().unwrap();
assert_eq!(
err,
OpError::InvalidValue("Input must have at least 2 dims")
);
}
#[test]
fn test_max_pool() {
let input = Tensor::from([
[0.1, 0.2, 0.3, 0.4],
[0.5, 0.6, 0.7, 0.8],
[0.1, 0.2, 0.3, 0.4],
[0.6, 0.7, 0.8, 0.9],
])
.into_shape([1, 1, 4, 4])
.into_dyn();
let input_1d = input.slice((.., .., 0, ..));
#[derive(Debug)]
struct Case<'a> {
input: TensorView<'a>,
kernel_size: Vec<usize>,
strides: Vec<usize>,
padding: Padding,
expected: Tensor,
}
let cases = [
Case {
input: input.view(),
kernel_size: [2, 2].into(),
strides: [2, 2].into(),
padding: [0, 0, 0, 0].into(),
expected: Tensor::from_data(&[1, 1, 2, 2], vec![0.6, 0.8, 0.7, 0.9]),
},
Case {
input: input.view(),
kernel_size: [4, 4].into(),
strides: [4, 4].into(),
padding: [0, 0, 0, 0].into(),
expected: Tensor::from_data(&[1, 1, 1, 1], vec![0.9]),
},
Case {
input: input.view(),
kernel_size: [2, 4].into(),
strides: [2, 4].into(),
padding: [0, 0, 0, 0].into(),
expected: Tensor::from_data(&[1, 1, 2, 1], vec![0.8, 0.9]),
},
Case {
input: input.view(),
kernel_size: [2, 2].into(),
strides: [1, 2].into(),
padding: [0, 0, 0, 0].into(),
expected: Tensor::from_data(
&[1, 1, 3, 2],
vec![
0.6, 0.8, 0.6, 0.8, 0.7, 0.9, ],
),
},
Case {
input: input.view(),
kernel_size: [2, 2].into(),
strides: [2, 1].into(),
padding: [0, 0, 0, 0].into(),
expected: Tensor::from_data(
&[1, 1, 2, 3],
vec![
0.6, 0.7, 0.8, 0.7, 0.8, 0.9, ],
),
},
Case {
input: input_1d.view(),
kernel_size: [2].into(),
strides: [2].into(),
padding: [0, 0].into(),
expected: Tensor::from_data(&[1, 1, 2], vec![0.2, 0.4]),
},
];
cases.test_each(|case| {
let pool = BufferPool::new();
let result = max_pool(
&pool,
case.input.view(),
&case.kernel_size,
&case.strides,
case.padding.clone(),
RoundMode::default(),
)
.unwrap();
expect_equal(&result, &case.expected).unwrap();
})
}
#[test]
fn test_max_pool_padding() {
let pool = BufferPool::new();
let input = Tensor::zeros(&[1, 1, 9, 9]);
let rm = RoundMode::default();
let result = max_pool(
&pool,
input.view(),
&[2, 2],
&[2, 2],
[0, 0, 0, 0].into(),
rm,
)
.unwrap();
assert_eq!(result.shape(), &[1, 1, 4, 4]);
let result = max_pool(
&pool,
input.view(),
&[2, 2],
&[2, 2],
[1, 1, 1, 1].into(),
rm,
)
.unwrap();
assert_eq!(result.shape(), &[1, 1, 5, 5]);
let result = max_pool(
&pool,
input.view(),
&[2, 2],
&[2, 2],
[2, 2, 2, 2].into(),
rm,
)
.unwrap();
assert_eq!(result.shape(), &[1, 1, 6, 6]);
let result = max_pool(&pool, input.view(), &[2, 2], &[2, 2], Padding::Same, rm).unwrap();
assert_eq!(result.shape(), &[1, 1, 5, 5]);
let result = max_pool(&pool, input.view(), &[2, 2], &[3, 3], Padding::Same, rm).unwrap();
assert_eq!(result.shape(), &[1, 1, 3, 3]);
}
#[test]
fn test_calc_output_size_and_padding() {
#[derive(Debug)]
struct Case {
in_size: (usize, usize),
kernel_size: (usize, usize),
dilations: (usize, usize),
strides: (usize, usize),
padding: Padding,
round_mode: RoundMode,
expected: Result<(usize, usize, [usize; 4]), OpError>,
}
impl Default for Case {
fn default() -> Self {
Case {
in_size: (5, 5),
kernel_size: (3, 3),
dilations: (1, 1),
strides: (1, 1),
padding: [0, 0, 0, 0].into(),
round_mode: RoundMode::Floor,
expected: Err(OpError::InvalidValue("default")),
}
}
}
let cases = [
Case {
expected: Ok((3, 3, [0, 0, 0, 0])),
..Default::default()
},
Case {
padding: [1, 1, 1, 1].into(),
expected: Ok((5, 5, [1, 1, 1, 1])),
..Default::default()
},
Case {
strides: (2, 2),
expected: Ok((2, 2, [0, 0, 0, 0])),
..Default::default()
},
Case {
dilations: (2, 2),
expected: Ok((1, 1, [0, 0, 0, 0])),
..Default::default()
},
Case {
in_size: (1, 20),
kernel_size: (1, 3),
padding: Padding::Same,
expected: Ok((1, 20, [0, 1, 0, 1])),
..Default::default()
},
Case {
in_size: (9, 9),
strides: (3, 3),
kernel_size: (2, 2),
padding: Padding::Same,
expected: Ok((3, 3, [0, 0, 0, 0])),
..Default::default()
},
Case {
in_size: (8, 8),
strides: (2, 2),
round_mode: RoundMode::Ceil,
expected: Ok((4, 4, [0, 0, 0, 0])),
..Default::default()
},
Case {
in_size: (8, 8),
strides: (2, 2),
round_mode: RoundMode::Floor,
expected: Ok((3, 3, [0, 0, 0, 0])),
..Default::default()
},
Case {
in_size: (7, 7),
strides: (2, 2),
round_mode: RoundMode::Ceil,
padding: Padding::Same,
expected: Ok((4, 4, [1, 1, 1, 1])),
..Default::default()
},
Case {
in_size: (7, 7),
strides: (2, 2),
round_mode: RoundMode::Floor,
padding: Padding::Same,
expected: Ok((4, 4, [1, 1, 1, 1])),
..Default::default()
},
Case {
in_size: (12, 12),
kernel_size: (1, 1),
strides: (2, 2),
round_mode: RoundMode::Ceil,
expected: Ok((6, 6, [0, 0, 0, 0])),
..Default::default()
},
Case {
strides: (0, 0),
expected: Err(OpError::InvalidValue("Strides must be > 0")),
..Default::default()
},
Case {
dilations: (0, 0),
expected: Err(OpError::InvalidValue("Dilations must be > 0")),
..Default::default()
},
Case {
kernel_size: (0, 0),
expected: Err(OpError::InvalidValue("Kernel size must be > 0")),
..Default::default()
},
Case {
padding: [0, 0].into(),
expected: Err(OpError::InvalidValue("Expected 4 padding values")),
..Default::default()
},
Case {
in_size: (4, 4),
dilations: (2, 2),
expected: Err(OpError::InvalidValue("Input too small for kernel size")),
..Default::default()
},
];
cases.test_each(|case| {
let Case {
in_size,
kernel_size,
dilations,
strides,
padding,
round_mode,
expected,
} = case;
assert_eq!(
&calc_output_size_and_padding(
*in_size,
*kernel_size,
*strides,
padding.clone(),
Some(*dilations),
*round_mode,
),
expected
);
})
}
}