use super::{DataFormat, DataShape, PaddingSpec};
use crate::ops::prelude::*;
use ndarray::prelude::*;
/// Geometry of a sliding kernel window over an input tensor, precomputed once
/// by [`Patch::new`] and reused for every output position.
#[derive(Debug, Clone, PartialEq)]
pub struct Patch {
    /// Kernel dilation, one entry per spatial axis.
    pub dilations: TVec<usize>,
    /// Kernel extent along each spatial axis.
    pub kernel_spatial_shape: TVec<usize>,
    /// Leading padding per spatial axis, as resolved by `PaddingSpec::compute`.
    pub pad_before: TVec<usize>,
    /// Trailing padding per spatial axis, as resolved by `PaddingSpec::compute`.
    pub pad_after: TVec<usize>,
    /// True when any entry of `pad_before` or `pad_after` is non-zero.
    pub padded: bool,
    /// Step between consecutive output positions, one entry per spatial axis.
    pub kernel_strides: TVec<usize>,
    /// Full input shape, interpreted through its `DataFormat`.
    pub input_shape: DataShape<usize, TVec<usize>>,
    /// Spatial dimensions of the output, as computed from the padding spec.
    pub output_spatial_shape: TVec<usize>,
    /// One row per kernel position (row-major order over the kernel shape),
    /// one column per spatial axis; each cell is
    /// `kernel_coord * dilation - pad_before` for that axis.
    pub data_field: Array2<isize>,
    /// `(min, max)` of each column of `data_field`.
    pub data_field_min_max: TVec<(isize, isize)>,
    /// Each `data_field` row flattened into a single element offset using the
    /// standard (row-major) strides of `input_shape`.
    pub standard_layout_data_field: Vec<isize>,
}
impl Patch {
    /// Precomputes the window geometry for a convolution-like operator.
    ///
    /// * `data_fmt` tells how `input_full_shape` splits into batch, channel and
    ///   spatial axes.
    /// * `dilations`, `kernel_spatial_shape` and `kernel_strides` carry one
    ///   value per spatial axis.
    /// * `padding` is resolved against the spatial input dims, yielding the
    ///   per-axis padding and the output spatial shape.
    ///
    /// Besides storing those, this builds `data_field` (per kernel tap, the
    /// dilated per-axis offset minus the leading padding), its per-axis
    /// extremes, and the same offsets flattened with the input's row-major
    /// strides.
    pub fn new(
        data_fmt: DataFormat,
        dilations: TVec<usize>,
        kernel_spatial_shape: TVec<usize>,
        padding: &PaddingSpec,
        kernel_strides: TVec<usize>,
        input_full_shape: TVec<usize>,
    ) -> Patch {
        use crate::ops::nn::padding::ComputedPaddedDim;

        let input_shape = data_fmt.shape(input_full_shape);
        let ComputedPaddedDim {
            pad_after,
            pad_before,
            output,
        } = padding.compute(
            input_shape.hw_dims(),
            &kernel_spatial_shape,
            &*dilations,
            &*kernel_strides,
        );

        let spatial_rank = kernel_spatial_shape.len();
        let kernel_len: usize = kernel_spatial_shape.iter().product();

        // Kernel taps are visited in row-major order; each one contributes a
        // row of dilated per-axis offsets shifted back by the leading padding.
        let mut offsets: Vec<isize> = Vec::with_capacity(kernel_len * spatial_rank);
        for position in ::ndarray::indices(&*kernel_spatial_shape) {
            for (axis, &k) in position.slice().iter().enumerate() {
                offsets.push((k * dilations[axis]) as isize - pad_before[axis] as isize);
            }
        }
        let data_field = Array2::from_shape_vec((kernel_len, spatial_rank), offsets).unwrap();

        // Smallest and largest offset reached along each spatial axis.
        let data_field_min_max: TVec<(isize, isize)> = data_field
            .gencolumns()
            .into_iter()
            .map(|column| {
                let lowest = *column.iter().min().unwrap();
                let highest = *column.iter().max().unwrap();
                (lowest, highest)
            })
            .collect();

        // Row-major strides of the full input shape: running products of the
        // trailing dims, with 1 for the innermost axis.
        let mut layout_strides: Vec<usize> = input_shape
            .shape
            .iter()
            .skip(1)
            .rev()
            .scan(1usize, |running, &dim| {
                *running *= dim;
                Some(*running)
            })
            .collect();
        layout_strides.reverse();
        layout_strides.push(1);

        // Spatial axes start at `h_axis`, so only the strides from there on
        // pair with the data-field columns.
        let h_axis = input_shape.h_axis();
        let standard_layout_data_field: Vec<isize> = data_field
            .outer_iter()
            .map(|row| {
                let mut flat = 0isize;
                for (&delta, &stride) in row.iter().zip(layout_strides.iter().skip(h_axis)) {
                    flat += delta * stride as isize;
                }
                flat
            })
            .collect();

        let padded = pad_before
            .iter()
            .chain(pad_after.iter())
            .any(|&p| p != 0);

        Patch {
            dilations,
            kernel_spatial_shape,
            pad_before,
            pad_after,
            padded,
            kernel_strides,
            input_shape,
            output_spatial_shape: output,
            data_field,
            data_field_min_max,
            standard_layout_data_field,
        }
    }

    /// Full output shape: the input shape with its channel axis set to
    /// `channels` and its spatial axes set to `output_spatial_shape`.
    pub fn output_full_shape(&self, channels: usize) -> TVec<usize> {
        let mut shape = self.input_shape.shape.clone();
        let c_axis = self.input_shape.c_axis();
        shape[c_axis] = channels;
        let spatial_axes = self.input_shape.hw_axes();
        shape[spatial_axes].copy_from_slice(&self.output_spatial_shape);
        shape
    }

    /// Binds this patch to a concrete input view.
    ///
    /// The visitor keeps the view's strides with the spatial ones multiplied
    /// by `kernel_strides`, so output coordinates map straight to input
    /// element offsets. The unchecked fast path of the visitor combines these
    /// with `standard_layout_data_field`, i.e. it assumes a standard-layout
    /// view.
    pub fn wrap<'i, 'p, T: Datum>(
        &'p self,
        input: &'i ArrayViewD<'i, T>,
    ) -> PatchVisitor<'i, 'p, T> {
        let mut fast_strides: Vec<isize> = input.strides().to_vec();
        let spatial = &mut fast_strides[self.input_shape.hw_axes()];
        for (stride, &kernel_stride) in spatial.iter_mut().zip(self.kernel_strides.iter()) {
            *stride *= kernel_stride as isize;
        }
        PatchVisitor {
            patch: self,
            input,
            valid: !self.padded,
            fast_strides,
        }
    }
}
/// A [`Patch`] bound to a concrete input view, able to enumerate the input
/// values under the kernel at any output position.
#[derive(Debug)]
pub struct PatchVisitor<'i, 'p, T: Datum> {
    patch: &'p Patch,
    input: &'i ArrayViewD<'i, T>,
    // When true, `at` skips the per-position bounds check (set from
    // `!patch.padded` by `Patch::wrap`).
    valid: bool,
    // Strides of `input` in elements, with the spatial axes multiplied by the
    // kernel strides so output coordinates map directly to input offsets.
    fast_strides: Vec<isize>,
}
impl<'i, 'p, T: Datum> PatchVisitor<'i, 'p, T> {
    /// Returns an iterator over the input values covered by the kernel at
    /// output position `coords`.
    ///
    /// `coords` has one entry per input axis; its spatial entries are output
    /// coordinates. Items are `Some(value)` for taps inside the input and
    /// `None` for taps that fall outside it (padding).
    ///
    /// The unchecked, pointer-based iterator is chosen only when the input
    /// view is in standard layout and either no bounds check is needed
    /// (`valid`) or every tap at this position is in bounds. Otherwise the
    /// bounds-checked iterator is returned.
    pub fn at<'v>(&'p self, coords: &[usize]) -> PatchIterator<'i, 'p, 'v, T>
    where
        'i: 'v,
        'p: 'v,
    {
        // The fast path adds row-major offsets (`standard_layout_data_field`)
        // to a center computed from the view's *actual* strides. The two only
        // agree for a standard-layout view; any other view would make the
        // unchecked read hit the wrong (or out-of-bounds) memory.
        let fast = self.input.is_standard_layout() && (self.valid || self.taps_in_bounds(coords));
        if fast {
            let center = coords
                .iter()
                .zip(self.fast_strides.iter())
                .map(|(&a, &b)| b * a as isize)
                .sum();
            PatchIterator::Fast(FastPatchIterator {
                visitor: self,
                ptr: self.input.as_ptr(),
                center,
                item: 0,
            })
        } else {
            // Input-space coordinates of the patch origin: spatial output
            // coordinates scaled by the kernel strides.
            let mut input_patch_center = coords.to_vec();
            input_patch_center[self.patch.input_shape.hw_axes()]
                .iter_mut()
                .zip(self.patch.kernel_strides.iter())
                .for_each(|(a, &b)| *a *= b);
            let input_patch_current = vec![0; coords.len()];
            PatchIterator::Safe(SafePatchIterator {
                visitor: self,
                item: 0,
                input_patch_center,
                input_patch_current,
            })
        }
    }

    /// True when every kernel tap at output position `coords` lands inside
    /// the input's spatial extent, judged from the per-axis offset extremes.
    fn taps_in_bounds(&self, coords: &[usize]) -> bool {
        let patch = self.patch;
        let hw_dims = patch.input_shape.hw_dims();
        coords[patch.input_shape.hw_axes()]
            .iter()
            .enumerate()
            .all(|(ix, &c)| {
                let origin = (c * patch.kernel_strides[ix]) as isize;
                origin + patch.data_field_min_max[ix].0 >= 0
                    && origin + patch.data_field_min_max[ix].1 < hw_dims[ix] as isize
            })
    }

    /// Element offset, relative to the input's first element, of kernel tap
    /// `patch_index` at output position `coords`.
    ///
    /// Not bounds-checked. It mixes the view's strides with row-major tap
    /// offsets, so it is only meaningful for a standard-layout input and a tap
    /// inside the input (a negative result wraps when cast to `usize`).
    pub fn global_offset_for(&self, coords: &[usize], patch_index: usize) -> usize {
        let center = coords
            .iter()
            .zip(self.fast_strides.iter())
            .map(|(&a, &b)| b * a as isize)
            .sum::<isize>();
        (center + self.patch.standard_layout_data_field[patch_index]) as usize
    }
}
/// Iterator over the input values under the kernel at one output position,
/// as returned by [`PatchVisitor::at`].
pub enum PatchIterator<'i: 'v, 'p: 'v, 'v, T: Datum> {
    /// Unchecked raw-pointer reads; every item is `Some(value)`.
    Fast(FastPatchIterator<'i, 'p, 'v, T>),
    /// Bounds-checked reads; taps outside the input yield `None`.
    Safe(SafePatchIterator<'i, 'p, 'v, T>),
}
// Lifetime parameters are listed in declaration order ('i, 'p, 'v) so the
// impl reads the same as the enum definition.
impl<'i: 'v, 'p: 'v, 'v, T: Datum + PartialEq> Iterator for PatchIterator<'i, 'p, 'v, T> {
    type Item = Option<T>;
    #[inline(always)]
    fn next(&mut self) -> Option<Option<T>> {
        match self {
            PatchIterator::Fast(it) => it.next(),
            PatchIterator::Safe(it) => it.next(),
        }
    }
}
/// Unchecked patch iterator: the `item`-th value is read through `ptr` at
/// element offset `center + standard_layout_data_field[item]`.
pub struct FastPatchIterator<'i: 'v, 'p: 'v, 'v, T: Datum> {
    visitor: &'v PatchVisitor<'i, 'p, T>,
    // Pointer to the input view's first element.
    ptr: *const T,
    // Element offset of the patch origin for the current output position.
    center: isize,
    // Index of the next kernel tap to yield.
    item: usize,
}
impl<'i: 'v, 'p: 'v, 'v, T: Datum + PartialEq> Iterator for FastPatchIterator<'i, 'p, 'v, T> {
    type Item = Option<T>;
    #[inline(always)]
    fn next(&mut self) -> Option<Option<T>> {
        let offsets = &self.visitor.patch.standard_layout_data_field;
        if self.item == offsets.len() {
            return None;
        }
        // SAFETY: `self.item < offsets.len()` was checked just above, so
        // `get_unchecked` is in bounds. The pointer read relies on the
        // constructor (`PatchVisitor::at`) having verified that every tap at
        // this position lies inside the input. It also requires the input
        // view to be in standard layout, because `offsets` were computed
        // from row-major strides while `center` uses the view's own strides.
        let value = unsafe {
            let delta = *offsets.get_unchecked(self.item);
            *self.ptr.offset(self.center + delta)
        };
        self.item += 1;
        Some(Some(value))
    }
}
pub struct SafePatchIterator<'i: 'v, 'p: 'v, 'v, T: Datum> {
visitor: &'v PatchVisitor<'i, 'p, T>,
item: usize,
input_patch_center: Vec<usize>,
input_patch_current: Vec<usize>,
}
impl<'i: 'v, 'p: 'v, 'v, T: Datum + PartialEq> Iterator for SafePatchIterator<'i, 'p, 'v, T> {
type Item = Option<T>;
#[inline(never)]
fn next(&mut self) -> Option<Option<T>> {
if self.item == self.visitor.patch.data_field.rows() {
return None;
}
let img_offset = self.visitor.patch.data_field.row(self.item);
self.item += 1;
(&mut *self.input_patch_current).copy_from_slice(&self.input_patch_center);
self.input_patch_current[self.visitor.patch.input_shape.hw_axes()]
.iter_mut()
.zip(img_offset.iter())
.for_each(|(x, &i)| *x = (*x as isize + i as isize) as usize);
Some(self.visitor.input.get(&*self.input_patch_current).cloned())
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::ops::nn::DataFormat::NCHW;

    /// Output length of a 1-D patch over an input of length `input`.
    fn compute_output_spatial_dim(
        input: usize,
        dilation: usize,
        kdim: usize,
        pad_before: usize,
        pad_after: usize,
        stride: usize,
    ) -> usize {
        let patch = Patch::new(
            DataFormat::NCHW,
            tvec![dilation],
            tvec![kdim],
            &PaddingSpec::Explicit(tvec![pad_before], tvec![pad_after]),
            tvec![stride],
            tvec![1, 1, input],
        );
        patch.output_spatial_shape[0]
    }
    #[test]
    fn basic() {
        assert_eq!(compute_output_spatial_dim(5, 1, 3, 0, 0, 1), 3);
    }
    #[test]
    fn strides() {
        assert_eq!(compute_output_spatial_dim(7, 1, 3, 0, 0, 2), 3);
    }
    #[test]
    fn padding() {
        assert_eq!(compute_output_spatial_dim(5, 1, 3, 1, 1, 1), 5);
    }
    #[test]
    fn strides_and_padding() {
        assert_eq!(compute_output_spatial_dim(7, 1, 3, 1, 1, 2), 4);
    }
    /// `data_field` of an unpadded, unit-stride patch over a 10^n input.
    fn field(kdim: &[usize], dilations: &[usize]) -> Array2<isize> {
        let patch = Patch::new(
            NCHW,
            dilations.into(),
            kdim.into(),
            &PaddingSpec::Explicit(tvec![0; kdim.len()], tvec![0; kdim.len()]),
            tvec![1; kdim.len()],
            tvec![10; kdim.len() + 2],
        );
        patch.data_field
    }
    #[test]
    fn test_field() {
        assert_eq!(field(&[3], &[1]), arr2(&[[0], [1], [2]]));
        assert_eq!(field(&[3], &[2]), arr2(&[[0], [2], [4]]));
        assert_eq!(
            field(&[2, 2], &[1, 1]),
            arr2(&[[0, 0], [0, 1], [1, 0], [1, 1]])
        );
        assert_eq!(
            field(&[2, 2], &[2, 1]),
            arr2(&[[0, 0], [0, 1], [2, 0], [2, 1]])
        );
    }
    #[test]
    fn padded_field_is_shifted_by_pad_before() {
        let patch = Patch::new(
            NCHW,
            tvec![1],
            tvec![3],
            &PaddingSpec::Explicit(tvec![1], tvec![1]),
            tvec![1],
            tvec![1, 1, 10],
        );
        assert!(patch.padded);
        assert_eq!(patch.output_spatial_shape[0], 10);
        assert_eq!(patch.data_field, arr2(&[[-1], [0], [1]]));
        assert_eq!(&patch.data_field_min_max[..], &[(-1isize, 1isize)][..]);
        assert_eq!(patch.standard_layout_data_field, vec![-1isize, 0, 1]);
    }
    #[test]
    fn output_full_shape_replaces_channels_and_spatial() {
        let patch = Patch::new(
            NCHW,
            tvec![1],
            tvec![3],
            &PaddingSpec::Explicit(tvec![0], tvec![0]),
            tvec![1],
            tvec![2, 1, 5],
        );
        assert!(!patch.padded);
        assert_eq!(&patch.output_full_shape(4)[..], &[2usize, 4, 3][..]);
    }
}