use crate::internal::*;
use crate::ops::cnn::PaddingSpec;
use crate::ops::nn::{DataFormat, DataShape};
use ndarray::prelude::*;
use super::PatchAxis;
use std::fmt::Debug;
use std::ops::Range;
use tract_itertools::{izip, Itertools};
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct PatchSpec {
pub input_shape: TVec<usize>,
pub input_inner_stride: usize,
pub output_inner_stride: usize,
pub kernel_shape: TVec<usize>,
pub strides: TVec<usize>,
pub dilations: TVec<usize>,
pub padding: PaddingSpec,
}
impl Debug for PatchSpec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"input: {} kernel: {} strides: {} dil: {} pad: {:?}",
self.input_shape.iter().join(","),
self.kernel_shape.iter().join(","),
self.strides.iter().join(","),
self.dilations.iter().join(","),
self.padding
)
}
}
impl PatchSpec {
pub fn for_full_shape(
data_format: DataFormat,
input_full_shape: &[usize],
) -> TractResult<PatchSpec> {
let shape = data_format.shape(input_full_shape.into())?;
Ok(Self::for_data_shape(shape))
}
pub fn for_data_shape(data_shape: DataShape) -> PatchSpec {
let input_shape: TVec<usize> = data_shape.hw_dims().into();
PatchSpec {
kernel_shape: tvec!(1; input_shape.len()),
input_inner_stride: *data_shape.w_stride(),
output_inner_stride: 1,
strides: tvec!(1; input_shape.len()),
dilations: tvec!(1; input_shape.len()),
padding: PaddingSpec::Valid,
input_shape,
}
}
pub fn with_kernel_shape(self, kernel_shape: TVec<usize>) -> PatchSpec {
PatchSpec { kernel_shape, ..self }
}
pub fn with_dilations(self, dilations: TVec<usize>) -> PatchSpec {
PatchSpec { dilations, ..self }
}
pub fn with_strides(self, strides: TVec<usize>) -> PatchSpec {
PatchSpec { strides, ..self }
}
pub fn with_padding(self, padding: PaddingSpec) -> PatchSpec {
PatchSpec { padding, ..self }
}
pub fn with_output_inner_stride(self, output_inner_stride: usize) -> PatchSpec {
PatchSpec { output_inner_stride, ..self }
}
pub fn into_patch(self) -> Patch {
let dims = self.padding.compute(
&self.input_shape,
&self.kernel_shape,
&self.dilations,
&self.strides,
);
let output: TVec<usize> = dims.iter().map(|d| d.convoluted).collect();
let pad_before: TVec<usize> = dims.iter().map(|d| d.pad_before).collect();
let pad_after: TVec<usize> = dims.iter().map(|d| d.pad_after).collect();
let data_field: Vec<isize> = ::ndarray::indices(&*self.kernel_shape)
.into_iter()
.flat_map(|coords| {
#[allow(clippy::unnecessary_to_owned)] coords
.slice()
.to_vec()
.into_iter()
.enumerate()
.map(|(ix, c)| (c * self.dilations[ix]) as isize - pad_before[ix] as isize)
})
.collect();
let data_field = Array2::from_shape_vec(
(self.kernel_shape.iter().cloned().product(), self.kernel_shape.len()),
data_field,
)
.unwrap();
let data_field_min_max: TVec<_> = data_field
.columns()
.into_iter()
.map(|col| (col.iter().min().cloned().unwrap(), col.iter().max().cloned().unwrap()))
.collect();
fn strides(shape: &[usize], inner: usize) -> TVec<isize> {
let mut strides: TVec<isize> = tvec![inner as isize];
for dim in shape.iter().skip(1).rev() {
let previous = *strides.last().unwrap();
strides.push(*dim as isize * previous);
}
strides.reverse();
strides
}
let input_storage_strides = strides(&self.input_shape, self.input_inner_stride);
let output_storage_strides = strides(&output, self.output_inner_stride);
let standard_layout_data_field: Vec<isize> = data_field
.outer_iter()
.map(|coords| izip!(coords, &input_storage_strides).map(|(a, b)| a * b).sum::<isize>())
.collect();
let regions: TVec<TVec<_>> = dims
.iter()
.enumerate()
.map(|(ix, d)| {
PatchAxis {
input_dim: self.input_shape[ix],
kernel_dim: self.kernel_shape[ix],
pad_before: d.pad_before,
pad_after: d.pad_after,
output_dim: d.convoluted,
stride: self.strides[ix],
dilation: self.dilations[ix],
}
.regions()
})
.collect::<TVec<_>>();
let zone_strides = strides(®ions.iter().map(|d| d.len()).collect::<TVec<_>>(), 1);
let zones: Vec<Zone> = regions
.iter()
.multi_cartesian_product()
.map(|regions| Zone {
input_zone_offset: 0,
output_ranges: regions.iter().map(|reg| reg.range.clone()).collect(),
output_shape: regions.iter().map(|reg| reg.range.end - reg.range.start).collect(),
output_zone_offset: izip!(®ions, &output_storage_strides)
.map(|(reg, &stride)| reg.range.start as isize * stride)
.sum::<isize>(),
valid: regions.iter().all(|reg| reg.mask.is_none()),
values_offsets: izip!(
0..,
ndarray::indices(&*self.kernel_shape),
&standard_layout_data_field
)
.filter(|(_ix, coords, _offset)| {
izip!(coords.slice(), ®ions)
.all(|(&x, axis)| !axis.mask.as_ref().map(|mask| mask[x]).unwrap_or(false))
})
.map(|(ix, _coords, &window_offset)| (ix, window_offset))
.collect(),
})
.collect();
let valid_zone = zones.iter().position(|z| z.valid);
let mut valid_output_zone = tvec!();
let mut invalid_output_zones = tvec!();
for ix in 0..self.input_shape.len() {
let min_max = data_field_min_max[ix];
let min = (-min_max.0 as usize).divceil(self.strides[ix]);
let max =
(self.input_shape[ix].saturating_sub(min_max.1 as usize)).divceil(self.strides[ix]);
if min != 0 {
let mut invalid = valid_output_zone.clone();
invalid.push(0..min);
while invalid.len() < output.len() {
invalid.push(0..output[invalid.len()])
}
invalid_output_zones.push(invalid);
}
if max < output[ix] {
let mut invalid = valid_output_zone.clone();
invalid.push(max..output[ix]);
while invalid.len() < output.len() {
invalid.push(0..output[invalid.len()])
}
invalid_output_zones.push(invalid);
}
valid_output_zone.push(min..max)
}
let op_strides_times_input_storage_strides =
izip!(&self.strides, &input_storage_strides).map(|(a, b)| (*a as isize * b)).collect();
Patch {
spec: self,
padded: pad_before.iter().any(|&p| p != 0) || pad_after.iter().any(|&p| p != 0),
pad_before,
pad_after,
output_shape: output,
data_field,
data_field_min_max,
standard_layout_data_field,
input_storage_strides,
output_storage_strides,
op_strides_times_input_storage_strides,
valid_output_zone,
invalid_output_zones,
zones,
valid_zone_id: valid_zone,
zone_strides,
}
}
}
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Patch {
pub spec: PatchSpec,
pub pad_before: TVec<usize>,
pub pad_after: TVec<usize>,
pub padded: bool,
pub output_shape: TVec<usize>,
pub data_field: Array2<isize>,
pub data_field_min_max: TVec<(isize, isize)>,
pub standard_layout_data_field: Vec<isize>,
pub op_strides_times_input_storage_strides: TVec<isize>,
pub valid_output_zone: TVec<Range<usize>>,
pub invalid_output_zones: TVec<TVec<Range<usize>>>,
pub zones: Vec<Zone>,
pub valid_zone_id: Option<usize>,
pub zone_strides: TVec<isize>,
pub input_storage_strides: TVec<isize>,
pub output_storage_strides: TVec<isize>,
}
impl Debug for Patch {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.spec)
}
}
impl Patch {
#[inline]
pub fn rank(&self) -> usize {
self.spec.input_shape.len()
}
unsafe fn is_valid(&self, coords: &[usize]) -> bool {
for ix in 0..self.rank() {
let c = *coords.get_unchecked(ix) as isize;
let strides = *self.spec.strides.get_unchecked(ix) as isize;
let pos = c * strides;
let min_max = self.data_field_min_max.get_unchecked(ix);
if pos + min_max.0 < 0
|| pos + min_max.1 >= *self.spec.input_shape.get_unchecked(ix) as isize
{
return false;
}
}
true
}
pub fn valid_zone(&self) -> Option<&Zone> {
self.valid_zone_id.map(|id| &self.zones[id])
}
#[inline]
pub fn visit_output(&self, mut acceptor: impl FnMut(&Scanner)) {
if self.zones.len() == 0 {
return;
}
let mut scanner = Scanner::new(self);
while !scanner.done() {
acceptor(&scanner);
scanner.next();
}
}
pub fn centers_offsets(&self) -> Vec<isize> {
if self.zones.len() == 0 {
return vec![];
}
let mut scanner = Scanner::new(self);
let len = self.output_shape.iter().cloned().product();
let mut v = Vec::with_capacity(len);
for _ in 0..len {
v.push(scanner.input_center_offset);
scanner.next()
}
v
}
pub fn at<'p>(&'p self, coords: &[usize]) -> PatchIterator<'p> {
self.at_hint(coords, None)
}
pub fn at_hint<'p>(&'p self, coords: &[usize], hint: Option<bool>) -> PatchIterator<'p> {
unsafe {
assert_eq!(coords.len(), self.spec.kernel_shape.len());
let mut center = 0;
for i in 0..self.op_strides_times_input_storage_strides.len() {
center += *self.op_strides_times_input_storage_strides.get_unchecked(i)
* *coords.get_unchecked(i) as isize;
}
let valid = hint.unwrap_or_else(|| !self.padded || self.is_valid(coords));
if valid {
PatchIterator::Fast(FastPatchIterator { patch: self, center, item: 0 })
} else {
let mut input_patch_center: TVec<_> = coords.into();
input_patch_center
.iter_mut()
.zip(self.spec.strides.iter())
.for_each(|(a, &b)| *a *= b);
PatchIterator::Safe(SafePatchIterator {
patch: self,
item: 0,
input_patch_center,
center,
})
}
}
}
pub fn global_offset_for(&self, coords: &[usize], patch_index: usize) -> usize {
assert_eq!(coords.len(), self.spec.kernel_shape.len());
let center = izip!(coords, &self.op_strides_times_input_storage_strides)
.map(|(a, b)| *a as isize * *b)
.sum::<isize>();
(center + self.standard_layout_data_field[patch_index]) as usize
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Zone {
pub valid: bool,
pub input_zone_offset: isize,
pub output_zone_offset: isize,
pub output_ranges: Box<[Range<usize>]>,
pub output_shape: Box<[usize]>,
pub values_offsets: Box<[(usize, isize)]>,
}
impl Zone {
pub fn contains_output(&self, coords: &[usize]) -> bool {
self.output_ranges.iter().zip(coords).all(|(range, &x)| x >= range.start && x < range.end)
}
#[inline]
pub fn visit_output(&self, patch: &Patch, mut acceptor: impl FnMut(&ZoneScanner)) {
let mut scanner = ZoneScanner::new(self, patch);
while !scanner.done() {
acceptor(&scanner);
scanner.next();
}
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ZoneScanner<'p> {
pub patch: &'p Patch,
pub zone: &'p Zone,
pub output_offset: isize,
pub output_coords: Box<[usize]>,
pub input_center_offset: isize,
pub inner_loop_axis: usize,
pub inner_loop_len: usize,
pub inner_loop_output_range: Range<usize>,
pub inner_loop_output_stride: isize,
pub inner_loop_input_full_stride: isize,
pub done: bool,
}
impl<'p> ZoneScanner<'p> {
pub fn new(zone: &'p Zone, patch: &'p Patch) -> ZoneScanner<'p> {
let inner_loop_axis =
zone.output_shape.iter().enumerate().max_by_key(|(_, dim)| *dim).unwrap().0;
let inner_loop_output_range = zone.output_ranges[inner_loop_axis].clone();
let inner_loop_output_stride = patch.output_storage_strides[inner_loop_axis];
let inner_loop_input_full_stride =
patch.op_strides_times_input_storage_strides[inner_loop_axis];
let mut scan = ZoneScanner {
patch,
zone,
output_offset: 0,
input_center_offset: 0,
inner_loop_axis,
inner_loop_len: inner_loop_output_range.len(),
inner_loop_output_range,
inner_loop_output_stride,
inner_loop_input_full_stride,
output_coords: zone.output_ranges.iter().map(|r| r.start).collect(),
done: false,
};
scan.refresh_dependent();
scan
}
#[inline]
pub fn valid_offsets_ker_in(&self) -> impl Iterator<Item = (usize, isize)> + '_ {
self.zone.values_offsets.iter().map(move |pair| (pair.0, pair.1 + self.input_center_offset))
}
pub unsafe fn next_non_inner_axis(&mut self) {
let rank = self.patch.rank();
let inner_loop_axis = self.inner_loop_axis;
for axis in (0..rank).rev() {
if axis == inner_loop_axis {
continue;
}
*self.output_coords.get_unchecked_mut(axis) += 1;
if *self.output_coords.get_unchecked_mut(axis)
< self.zone.output_ranges.get_unchecked(axis).end
{
self.refresh_dependent();
return;
}
*self.output_coords.get_unchecked_mut(axis) =
self.zone.output_ranges.get_unchecked(axis).start;
}
self.done = true;
}
pub unsafe fn reset(&mut self) {
self.output_offset = 0;
self.input_center_offset = 0;
for ix in 0..self.output_coords.len() {
*self.output_coords.get_unchecked_mut(ix) =
self.zone.output_ranges.get_unchecked(ix).start;
}
self.done = false;
self.refresh_dependent()
}
#[inline(never)]
fn refresh_dependent(&mut self) {
self.input_center_offset = self
.patch
.op_strides_times_input_storage_strides
.iter()
.zip(self.output_coords.iter())
.map(|(a, b)| *a * *b as isize)
.sum();
self.output_offset = self
.patch
.output_storage_strides
.iter()
.zip(self.output_coords.iter())
.map(|(a, b)| a * *b as isize)
.sum();
}
#[inline]
pub fn next(&mut self) {
let inner_loop_axis = self.inner_loop_axis;
unsafe {
*self.output_coords.get_unchecked_mut(inner_loop_axis) += 1;
if *self.output_coords.get_unchecked(inner_loop_axis) < self.inner_loop_output_range.end
{
self.input_center_offset += self.inner_loop_input_full_stride;
self.output_offset += self.inner_loop_output_stride;
} else {
*self.output_coords.get_unchecked_mut(inner_loop_axis) =
self.inner_loop_output_range.start;
self.next_non_inner_axis();
}
}
}
pub fn done(&self) -> bool {
self.done
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Scanner<'p> {
pub patch: &'p Patch,
pub zone_id: usize,
pub zone_coords: TVec<usize>,
pub zone: &'p Zone,
pub output_offset: isize,
pub output_coords: TVec<usize>,
pub input_coords: TVec<usize>,
pub input_center_offset: isize,
done: bool,
}
impl<'p> Scanner<'p> {
fn new(patch: &'p Patch) -> Scanner<'p> {
let rank = patch.rank();
let zone = &patch.zones[0];
Scanner {
patch,
zone_coords: tvec!(0; rank),
zone,
zone_id: 0,
output_offset: 0,
input_center_offset: 0,
input_coords: tvec!(0; rank),
output_coords: tvec!(0; rank),
done: false,
}
}
#[inline]
pub fn valid_count(&self) -> usize {
self.zone.values_offsets.len()
}
#[inline]
pub fn valid_offsets(&self) -> impl Iterator<Item = isize> + '_ {
self.zone.values_offsets.iter().map(move |pair| pair.1 + self.input_center_offset)
}
#[inline]
pub fn valid_offsets_ker_in(&self) -> impl Iterator<Item = (usize, isize)> + '_ {
self.zone.values_offsets.iter().map(move |pair| (pair.0, pair.1 + self.input_center_offset))
}
#[inline]
pub fn next(&mut self) {
let rank = self.patch.rank();
let inner_dim = rank - 1;
unsafe {
*self.output_coords.get_unchecked_mut(inner_dim) += 1;
*self.input_coords.get_unchecked_mut(inner_dim) +=
*self.patch.spec.strides.get_unchecked(inner_dim);
self.output_offset += self.patch.spec.output_inner_stride as isize;
self.input_center_offset +=
self.patch.op_strides_times_input_storage_strides.get_unchecked(inner_dim);
if *self.output_coords.get_unchecked(inner_dim)
< self.zone.output_ranges.get_unchecked(inner_dim).end
{
return;
}
if self.output_coords.get_unchecked(inner_dim)
< self.patch.output_shape.get_unchecked(inner_dim)
{
self.zone_id += 1;
*self.zone_coords.get_unchecked_mut(inner_dim) += 1;
self.zone = self.patch.zones.get_unchecked(self.zone_id);
} else {
for axis in (0..rank - 1).rev() {
*self.output_coords.get_unchecked_mut(axis + 1) = 0;
*self.input_coords.get_unchecked_mut(axis + 1) = 0;
*self.output_coords.get_unchecked_mut(axis) += 1;
*self.input_coords.get_unchecked_mut(axis) +=
self.patch.spec.strides.get_unchecked(axis);
*self.zone_coords.get_unchecked_mut(axis + 1) = 0;
if *self.output_coords.get_unchecked(axis)
== self.zone.output_ranges.get_unchecked(axis).end
{
*self.zone_coords.get_unchecked_mut(axis) += 1;
}
if *self.output_coords.get_unchecked(axis)
< *self.patch.output_shape.get_unchecked(axis)
{
break;
}
}
if self.output_coords.get_unchecked(0) == self.patch.output_shape.get_unchecked(0) {
self.done = true;
return;
}
self.zone_id = 0;
self.input_center_offset = 0;
for i in 0..rank {
self.zone_id += *self.zone_coords.get_unchecked(i)
* *self.patch.zone_strides.get_unchecked(i) as usize;
self.input_center_offset += *self.input_coords.get_unchecked(i) as isize
* *self.patch.input_storage_strides.get_unchecked(i);
}
self.zone = self.patch.zones.get_unchecked(self.zone_id);
}
}
}
pub fn done(&self) -> bool {
self.done
}
}
#[derive(Debug)]
pub enum PatchIterator<'p> {
Fast(FastPatchIterator<'p>),
Safe(SafePatchIterator<'p>),
}
impl<'p> Iterator for PatchIterator<'p> {
type Item = Option<isize>;
#[inline(always)]
fn next(&mut self) -> Option<Option<isize>> {
match self {
PatchIterator::Fast(ref mut it) => it.next(),
PatchIterator::Safe(ref mut it) => it.next(),
}
}
}
#[derive(Debug)]
pub struct FastPatchIterator<'p> {
patch: &'p Patch,
center: isize,
item: usize,
}
impl<'p> Iterator for FastPatchIterator<'p> {
type Item = Option<isize>;
#[inline(always)]
fn next(&mut self) -> Option<Option<isize>> {
if self.item == self.patch.standard_layout_data_field.len() {
return None;
}
unsafe {
let position =
self.center + self.patch.standard_layout_data_field.get_unchecked(self.item);
self.item += 1;
Some(Some(position))
}
}
}
#[derive(Debug)]
pub struct SafePatchIterator<'p> {
patch: &'p Patch,
item: usize,
input_patch_center: TVec<usize>,
center: isize,
}
impl<'p> Iterator for SafePatchIterator<'p> {
type Item = Option<isize>;
fn next(&mut self) -> Option<Option<isize>> {
unsafe {
if self.item == self.patch.standard_layout_data_field.len() {
return None;
}
let input_shape = &self.patch.spec.input_shape;
let img_offset = self.patch.data_field.as_ptr().add(self.item * input_shape.len());
for ix in 0..input_shape.len() {
let pos = *self.input_patch_center.get_unchecked(ix) as isize + *img_offset.add(ix);
if pos < 0 || pos as usize >= *input_shape.get_unchecked(ix) {
self.item += 1;
return Some(None);
}
}
let position =
self.center + self.patch.standard_layout_data_field.get_unchecked(self.item);
self.item += 1;
Some(Some(position))
}
}
}
#[cfg(test)]
pub mod test {
use super::*;
use crate::ops::nn::DataFormat::*;
use proptest::prelude::*;
use proptest::*;
fn compute_output_spatial_dim(
input: usize,
dilation: usize,
kdim: usize,
pad_before: usize,
bad_after: usize,
stride: usize,
) -> usize {
let patch = PatchSpec::for_full_shape(NCHW, &[1, 1, input])
.unwrap()
.with_dilations(tvec!(dilation))
.with_kernel_shape(tvec!(kdim))
.with_padding(PaddingSpec::Explicit(tvec![pad_before], tvec![bad_after], true))
.with_strides(tvec![stride])
.into_patch();
patch.output_shape[0]
}
#[test]
fn basic() {
assert_eq!(compute_output_spatial_dim(5, 1, 3, 0, 0, 1), 3);
}
#[test]
fn strides() {
assert_eq!(compute_output_spatial_dim(7, 1, 3, 0, 0, 2), 3);
}
#[test]
fn padding() {
assert_eq!(compute_output_spatial_dim(5, 1, 3, 1, 1, 1), 5);
}
#[test]
fn strides_and_padding() {
assert_eq!(compute_output_spatial_dim(7, 1, 3, 1, 1, 2), 4);
}
fn field(kdim: &[usize], dilations: &[usize]) -> Array2<isize> {
let patch =
PatchSpec::for_data_shape(NCHW.from_n_c_hw(1, 1, tvec![10; kdim.len()]).unwrap())
.with_dilations(dilations.into())
.with_kernel_shape(kdim.into())
.with_strides(tvec![1; kdim.len()])
.into_patch();
patch.data_field
}
#[test]
fn test_field() {
assert_eq!(field(&[3], &[1]), arr2(&[[0], [1], [2]]));
assert_eq!(field(&[3], &[2]), arr2(&[[0], [2], [4]]));
assert_eq!(field(&[2, 2], &[1, 1]), arr2(&[[0, 0], [0, 1], [1, 0], [1, 1]]));
assert_eq!(field(&[2, 2], &[2, 1]), arr2(&[[0, 0], [0, 1], [2, 0], [2, 1]]));
}
pub fn tensor(shape: &[usize]) -> BoxedStrategy<Tensor> {
let len = shape.iter().product::<usize>();
let shape = shape.to_vec();
proptest::collection::vec(any::<i8>().prop_map(|i| i as f32), len..=len)
.prop_map(move |vec| ArrayD::from_shape_vec(shape.clone(), vec).unwrap().into_tensor())
.boxed()
}
#[derive(Debug)]
struct Problem {
patch: Patch,
input: Tensor,
data_format: DataFormat,
}
impl Arbitrary for Problem {
type Parameters = ();
type Strategy = BoxedStrategy<Problem>;
fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
(
prop_oneof!(Just(NCHW), Just(NHWC)),
(1usize..3, 1usize..3),
1usize..3,
(1usize..3, 1usize..3),
prop_oneof![
Just(PaddingSpec::Valid),
Just(PaddingSpec::SameLower),
Just(PaddingSpec::SameUpper)
],
(1usize..4, 1usize..4),
)
.prop_flat_map(|p| {
let dil = p.1;
let ks = p.3;
let strides = p.5;
let min_size: (usize, usize) = (1 + (ks.0 - 1) * dil.0, 1 + (ks.1 - 1) * dil.1);
(
Just(p),
(min_size.0..min_size.0 + strides.0 * 3),
(min_size.1..min_size.1 + strides.1 * 3),
)
})
.prop_flat_map(|(p, h, w)| {
let input_shape = p.0.from_n_c_hw(1, p.2, [h, w]).unwrap();
let input = tensor(&input_shape.shape);
(Just(p), input)
})
.prop_map(|((fmt, dil, c, ks, pad, strides), input)| {
let output_inner_stride = if fmt.c_is_last() { c } else { 1 };
Problem {
patch: PatchSpec::for_full_shape(fmt, input.shape())
.unwrap()
.with_dilations(tvec!(dil.0, dil.1))
.with_kernel_shape(tvec!(ks.0, ks.1))
.with_padding(pad)
.with_strides(tvec![strides.0, strides.1])
.with_output_inner_stride(output_inner_stride)
.into_patch(),
input,
data_format: fmt,
}
})
.boxed()
}
}
impl Problem {
fn input_shape(&self) -> DataShape {
self.data_format.shape(self.input.shape().into()).unwrap()
}
fn output_shape(&self) -> DataShape {
self.data_format
.from_n_c_hw(
self.input_shape().n().cloned().unwrap_or(1),
*self.input_shape().c(),
&*self.patch.output_shape,
)
.unwrap()
}
fn reference_sumpool(&self) -> Tensor {
let input_shape = self.input_shape();
let output_shape = self.output_shape();
let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
for geo_out in tract_ndarray::indices(output_shape.hw_dims()) {
for geo_ker in tract_ndarray::indices(&*self.patch.spec.kernel_shape) {
let geo_in: TVec<isize> = izip!(
geo_out.slice(),
geo_ker.slice(),
&self.patch.spec.strides,
&self.patch.spec.dilations,
&self.patch.pad_before
)
.map(|(o, k, s, d, p)| (o * s + k * d) as isize - *p as isize)
.collect();
if izip!(&geo_in, input_shape.hw_dims())
.any(|(g, i)| *g >= *i as isize || *g < 0)
{
continue;
}
let geo_in: TVec<usize> = geo_in.into_iter().map(|x| x as usize).collect();
for c in 0..*output_shape.c() {
let ocoords = self.data_format.from_n_c_hw(0, c, geo_out.slice()).unwrap();
let icoords = self.data_format.from_n_c_hw(0, c, &geo_in).unwrap();
output.to_array_view_mut::<f32>().unwrap()[&*ocoords.shape] +=
self.input.to_array_view::<f32>().unwrap()[&*icoords.shape];
}
}
}
output
}
fn check_visitor(&self) {
let input_shape = self.input_shape();
let output_shape = self.output_shape();
let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
self.patch.visit_output(|visitor| {
for (_k, offset_in) in visitor.valid_offsets_ker_in() {
for c in 0..*output_shape.c() {
output.as_slice_mut::<f32>().unwrap()
[visitor.output_offset as usize + c * output_shape.c_stride()] +=
self.input.as_slice::<f32>().unwrap()
[offset_in as usize + c * input_shape.c_stride()];
}
}
});
assert_eq!(output, self.reference_sumpool());
}
fn check_zone_visitor(&self) {
let input_shape = self.input_shape();
let output_shape = self.output_shape();
let mut output = Tensor::zero::<f32>(&output_shape.shape).unwrap();
for zone in &self.patch.zones {
zone.visit_output(&self.patch, |visitor| {
for (_k, offset_in) in visitor.valid_offsets_ker_in() {
for c in 0..*output_shape.c() {
output.as_slice_mut::<f32>().unwrap()
[visitor.output_offset as usize + c * output_shape.c_stride()] +=
self.input.as_slice::<f32>().unwrap()
[offset_in as usize + c * input_shape.c_stride()];
}
}
});
}
assert_eq!(output, self.reference_sumpool());
}
fn check_zoning(&self) {
fn in_zone(full_coords: &[usize], h_axis: usize, zone: &[Range<usize>]) -> bool {
for a in 0..zone.len() {
if full_coords[h_axis + a] < zone[a].start
|| full_coords[h_axis + a] >= zone[a].end
{
return false;
}
}
true
}
let valid_zone = &self.patch.valid_output_zone;
let invalid_zones = &self.patch.invalid_output_zones;
let output_full_shape = self.output_shape();
let h_axis = self.input_shape().h_axis();
for coords in ndarray::indices(&*output_full_shape.shape) {
let inside_valid = in_zone(coords.slice(), h_axis, valid_zone);
let invalid_count =
invalid_zones.iter().filter(|z| in_zone(coords.slice(), h_axis, z)).count();
unsafe {
assert_eq!(
inside_valid,
self.patch.is_valid(&coords.slice()[self.input_shape().hw_axes()]),
"coords {:?}, valid_zone: {:?} inside_valid: {:?}",
coords.slice(),
valid_zone,
inside_valid
);
}
if inside_valid {
assert_eq!(invalid_count, 0);
} else {
assert_eq!(
invalid_count,
1,
"coords {:?}, valid_zone: {:?} inside_valid: {:?} invalid_zones: {:?}",
coords.slice(),
valid_zone,
inside_valid,
invalid_zones
);
}
}
}
}
proptest! {
#[test]
fn test_visitor(pb in any::<Problem>()) {
pb.check_visitor();
}
#[test]
fn test_zone_visitor(pb in any::<Problem>()) {
pb.check_zone_visitor();
}
#[test]
fn test_zoning(pb in any::<Problem>()) {
pb.check_zoning();
}
}
#[test]
fn test_visitor_1() {
let input_shape = NCHW.from_n_c_hw(1, 1, [2, 2]).unwrap();
let input = Tensor::zero::<f32>(&input_shape.shape).unwrap();
let patch = PatchSpec::for_data_shape(input_shape.clone())
.with_kernel_shape(tvec![2, 1])
.with_padding(PaddingSpec::SameLower)
.with_strides(tvec![1, 2])
.into_patch();
Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
}
#[test]
fn test_visitor_2() {
let input_shape = NCHW.from_n_c_hw(1, 2, [1, 1]).unwrap();
let input = tensor4(&[[[[0.]], [[1f32]]]]);
assert_eq!(input.shape(), &*input_shape.shape);
let patch =
PatchSpec::for_data_shape(input_shape.clone()).with_output_inner_stride(2).into_patch();
Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
}
#[test]
fn test_visitor_3() {
let input_shape = NHWC.from_n_c_hw(1, 2, [2, 1]).unwrap();
let input = tensor4(&[[[[0., 0.]], [[1., 0f32]]]]);
assert_eq!(input.shape(), &*input_shape.shape);
let patch =
PatchSpec::for_data_shape(input_shape.clone()).with_output_inner_stride(2).into_patch();
Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
}
#[test]
fn test_visitor_4() {
let input_shape = NCHW.from_n_c_hw(1, 1, [1, 2]).unwrap();
let input = tensor4(&[[[[0., 1f32]]]]);
assert_eq!(input.shape(), &*input_shape.shape);
let patch = PatchSpec::for_data_shape(input_shape.clone())
.with_kernel_shape(tvec!(1, 2))
.with_output_inner_stride(1)
.with_padding(PaddingSpec::SameLower)
.into_patch();
Problem { patch, input, data_format: input_shape.fmt }.check_visitor();
}
#[test]
fn test_zone_visitor_1() {
let input_shape = NCHW.from_n_c_hw(1, 1, [2, 1]).unwrap();
let input = tensor4(&[[[[0.], [1f32]]]]);
assert_eq!(input.shape(), &*input_shape.shape);
let patch = PatchSpec::for_data_shape(input_shape.clone()).into_patch();
Problem { patch, input, data_format: input_shape.fmt }.check_zone_visitor();
}
#[test]
fn test_zone_visitor_2() {
let input_shape = NCHW.from_n_c_hw(1, 1, [1, 2]).unwrap();
let input = tensor4(&[[[[0., 1f32]]]]);
assert_eq!(input.shape(), &*input_shape.shape);
let patch = PatchSpec::for_data_shape(input_shape.clone()).into_patch();
Problem { patch, input, data_format: input_shape.fmt }.check_zone_visitor();
}
}