use ndarray::{Array, ArrayBase, Data, DataMut, Dimension, IntoDimension, Ix1, Ix2, Ix3, Slice};
use std::fmt::Debug;
/// A strategy for filling the border that surrounds an array when it is padded.
///
/// In both methods, `padding` holds one entry per axis: the number of elements added on
/// *each* side of that axis.
pub trait PaddingMode: Send + Sync + Clone + Debug {
    /// Writes `original`, padded according to this mode, into the pre-allocated `array`.
    ///
    /// `array` is expected to already have the padded shape, i.e. a length of
    /// `original.len_of(k) + 2 * padding[k]` along every axis `k`.
    fn pad_inplace<D, S, T>(
        &self,
        array: &mut ArrayBase<S, D>,
        original: &ArrayBase<T, D>,
        padding: &[usize],
    ) where
        D: ReflPad + ReplPad,
        S: DataMut<Elem = f32>,
        T: Data<Elem = f32>;

    /// Returns a freshly allocated, padded copy of `input`.
    fn pad<D, E>(&self, input: &Array<f32, D>, padding: E) -> Array<f32, D>
    where
        D: ReflPad + ReplPad,
        E: IntoDimension<Dim = D>;
}
/// Padding mode that fills the border with `0.0`.
///
/// Zero-sized, so it derives `Copy`/`Default` and can be passed around by value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Zero;
/// Padding mode that fills the border with a fixed value.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Constant {
    /// The value written into every padded element.
    pub value: f32,
}

impl Constant {
    /// Creates a constant padding mode that fills the border with `value`.
    pub fn new(value: f32) -> Self {
        Self { value }
    }
}
/// Padding mode that mirrors the array around its edges without repeating the edge
/// element (see [`ReflPad`]).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Reflective;
/// Padding mode that repeats the outermost element of each axis into the border
/// (see [`ReplPad`]).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Replicative;
impl PaddingMode for Zero {
    /// Fills the whole of `input` with `0.0`, then copies `original` into its centre.
    ///
    /// # Panics
    /// If `padding.len()` differs from the number of axes of `input`.
    fn pad_inplace<D, S, T>(
        &self,
        input: &mut ArrayBase<S, D>,
        original: &ArrayBase<T, D>,
        padding: &[usize],
    ) where
        D: Dimension,
        S: DataMut<Elem = f32>,
        T: Data<Elem = f32>,
    {
        let (pad_axes, array_axes) = (padding.len(), input.ndim());
        assert_eq!(
            pad_axes,
            array_axes,
            "error: padding length {} doesn't match array dimensions {}",
            pad_axes,
            array_axes
        );
        constant_pad_inplace(input, original, padding, 0.);
    }

    /// Returns a new array holding `input` surrounded by zeros.
    fn pad<D, E>(&self, input: &Array<f32, D>, padding: E) -> Array<f32, D>
    where
        D: Dimension,
        E: IntoDimension<Dim = D>,
    {
        const FILL: f32 = 0.;
        constant_pad(input, padding, FILL)
    }
}
impl PaddingMode for Constant {
    /// Fills the whole of `input` with `self.value`, then copies `original` into its centre.
    ///
    /// # Panics
    /// If `padding.len()` differs from the number of axes of `input`.
    fn pad_inplace<D, S, T>(
        &self,
        input: &mut ArrayBase<S, D>,
        original: &ArrayBase<T, D>,
        padding: &[usize],
    ) where
        D: Dimension,
        S: DataMut<Elem = f32>,
        T: Data<Elem = f32>,
    {
        let (pad_axes, array_axes) = (padding.len(), input.ndim());
        assert_eq!(
            pad_axes,
            array_axes,
            "error: padding length {} doesn't match array dimensions {}",
            pad_axes,
            array_axes
        );
        constant_pad_inplace(input, original, padding, self.value);
    }

    /// Returns a new array holding `input` surrounded by `self.value`.
    fn pad<D, E>(&self, input: &Array<f32, D>, padding: E) -> Array<f32, D>
    where
        D: Dimension,
        E: IntoDimension<Dim = D>,
    {
        constant_pad(input, padding, self.value)
    }
}
impl PaddingMode for Reflective {
    /// Writes `original` into `input`, mirroring it into the border via [`ReflPad`].
    ///
    /// # Panics
    /// If `padding.len()` differs from the number of axes of `input`.
    fn pad_inplace<D, S, T>(
        &self,
        input: &mut ArrayBase<S, D>,
        original: &ArrayBase<T, D>,
        padding: &[usize],
    ) where
        D: ReflPad,
        S: DataMut<Elem = f32>,
        T: Data<Elem = f32>,
    {
        let (pad_axes, array_axes) = (padding.len(), input.ndim());
        assert_eq!(
            pad_axes,
            array_axes,
            "error: padding length {} doesn't match array dimensions {}",
            pad_axes,
            array_axes
        );
        D::reflection_pad_inplace(input, original, padding);
    }

    /// Returns a new, reflection-padded copy of `input`.
    fn pad<D, E>(&self, input: &Array<f32, D>, padding: E) -> Array<f32, D>
    where
        D: ReflPad,
        E: IntoDimension<Dim = D>,
    {
        let pad_dim = padding.into_dimension();
        D::reflection_pad(input, pad_dim.slice())
    }
}
impl PaddingMode for Replicative {
    /// Writes `original` into `input`, repeating its edge elements into the border via
    /// [`ReplPad`].
    ///
    /// # Panics
    /// If `padding.len()` differs from the number of axes of `input`.
    fn pad_inplace<D, S, T>(
        &self,
        input: &mut ArrayBase<S, D>,
        original: &ArrayBase<T, D>,
        padding: &[usize],
    ) where
        D: ReplPad,
        S: DataMut<Elem = f32>,
        T: Data<Elem = f32>,
    {
        let (pad_axes, array_axes) = (padding.len(), input.ndim());
        assert_eq!(
            pad_axes,
            array_axes,
            "error: padding length {} doesn't match array dimensions {}",
            pad_axes,
            array_axes
        );
        D::replication_pad_inplace(input, original, padding);
    }

    /// Returns a new, replication-padded copy of `input`.
    fn pad<D, E>(&self, input: &Array<f32, D>, padding: E) -> Array<f32, D>
    where
        D: ReplPad,
        E: IntoDimension<Dim = D>,
    {
        let pad_dim = padding.into_dimension();
        D::replication_pad(input, pad_dim.slice())
    }
}
/// Reflection ("mirror") padding, implemented separately for each supported dimensionality.
///
/// The border mirrors the data around its first and last element without repeating the
/// edge itself: `[a, b, c]` padded by 1 becomes `[b, a, b, c, b]`.
pub trait ReflPad: Dimension {
    /// Returns a new array holding `input` reflection-padded by `padding[k]` on both sides
    /// of every axis `k`.
    fn reflection_pad<S>(input: &ArrayBase<S, Self>, padding: &[usize]) -> Array<f32, Self>
    where
        S: DataMut<Elem = f32>;

    /// Writes `input`, reflection-padded by `padding`, into the pre-allocated `to_pad`.
    fn reflection_pad_inplace<S, T>(
        to_pad: &mut ArrayBase<S, Self>,
        input: &ArrayBase<T, Self>,
        padding: &[usize],
    ) where
        S: DataMut<Elem = f32>,
        T: Data<Elem = f32>;
}
/// Replication ("edge") padding, implemented separately for each supported dimensionality.
///
/// The border repeats the outermost element of each axis: `[a, b, c]` padded by 1 becomes
/// `[a, a, b, c, c]`.
pub trait ReplPad: Dimension {
    /// Returns a new array holding `input` replication-padded by `padding[k]` on both sides
    /// of every axis `k`.
    fn replication_pad<S>(input: &ArrayBase<S, Self>, padding: &[usize]) -> Array<f32, Self>
    where
        S: DataMut<Elem = f32>;

    /// Writes `input`, replication-padded by `padding`, into the pre-allocated `to_pad`.
    fn replication_pad_inplace<S, T>(
        to_pad: &mut ArrayBase<S, Self>,
        input: &ArrayBase<T, Self>,
        padding: &[usize],
    ) where
        S: DataMut<Elem = f32>,
        T: Data<Elem = f32>;
}
fn constant_pad<S, D, E>(input: &ArrayBase<S, D>, padding: E, val: f32) -> Array<f32, D>
where
D: Dimension,
S: DataMut<Elem = f32>,
E: IntoDimension<Dim = D>,
{
let padding_into_dim = padding.into_dimension();
let padded_shape = {
let mut padded_shape = input.raw_dim();
padded_shape
.slice_mut()
.iter_mut()
.zip(padding_into_dim.slice().iter())
.for_each(|(ax_len, pad)| *ax_len += pad * 2);
padded_shape
};
let mut padded = Array::zeros(padded_shape);
constant_pad_inplace(&mut padded, input, padding_into_dim.slice(), val);
padded
}
/// Fills `input` with `val`, then copies `original` into its centre.
///
/// `input` must already have the padded shape, i.e. a length of
/// `original.len_of(k) + 2 * padding[k]` along every axis `k`.
///
/// # Panics
/// If `padding` does not have one entry per axis, or if the shape of `input` is not the
/// padded shape of `original`. The shape is checked explicitly because `assign` broadcasts:
/// an `original` axis of length 1 would otherwise be silently smeared across a wrongly
/// sized centre region.
fn constant_pad_inplace<S, T, D>(
    input: &mut ArrayBase<S, D>,
    original: &ArrayBase<T, D>,
    padding: &[usize],
    val: f32,
) where
    D: Dimension,
    S: DataMut<Elem = f32>,
    T: Data<Elem = f32>,
{
    assert_eq!(
        padding.len(),
        original.ndim(),
        "error: padding length {} doesn't match array dimensions {}",
        padding.len(),
        original.ndim()
    );
    let shapes_match = input.ndim() == original.ndim()
        && input
            .shape()
            .iter()
            .zip(original.shape())
            .zip(padding)
            .all(|((&padded, &len), &pad)| padded == len + 2 * pad);
    assert!(
        shapes_match,
        "error: destination shape {:?} is not the padded shape of {:?} with padding {:?}",
        input.shape(),
        original.shape(),
        padding
    );

    input.map_inplace(|el| *el = val);
    let mut centre = input.view_mut();
    // Select `pad..pad + len` on every axis; unlike a `pad..-pad` range this needs no
    // special case for `pad == 0`.
    centre.slice_each_axis_inplace(|ax| {
        let pad = padding[ax.axis.index()];
        let len = original.len_of(ax.axis);
        Slice::from(pad as isize..(pad + len) as isize)
    });
    centre.assign(original);
}
/// Maps position `i` of a reflection-padded axis back to an index into the source axis.
///
/// The source axis has `len` elements, and `pad` extra positions are added on each side.
/// The border mirrors the data around the first and last element without repeating them,
/// so for `len = 3, pad = 1` the padded positions `0..5` map to `[1, 0, 1, 2, 1]`.
///
/// Only meaningful for `pad < len`. Otherwise some returned indices fall outside the
/// source axis (or, in debug builds, the subtraction overflows), and the caller's indexing
/// panics.
fn reflect_index(i: usize, pad: usize, len: usize) -> usize {
    if i < pad {
        // Left border: mirror around index 0.
        pad - i
    } else if i < len + pad {
        // Interior: plain shift by the left padding.
        i - pad
    } else {
        // Right border: mirror around index `len - 1`.
        2 * (len - 1) + pad - i
    }
}

impl ReflPad for Ix1 {
    /// Returns a new 1-D array holding `input` reflection-padded by `padding[0]` at both ends.
    fn reflection_pad<S: DataMut<Elem = f32>>(
        input: &ArrayBase<S, Ix1>,
        padding: &[usize],
    ) -> Array<f32, Ix1> {
        let mut out = Array::<f32, _>::zeros(input.len() + 2 * padding[0]);
        Self::reflection_pad_inplace(&mut out, input, padding);
        out
    }

    /// Overwrites every element of `to_pad` with `input` reflection-padded by `padding[0]`.
    ///
    /// Uses logical indexing rather than raw slices, so both arrays may have any memory
    /// layout (strided, transposed, or Fortran-order views are accepted).
    fn reflection_pad_inplace<S: DataMut<Elem = f32>, T: Data<Elem = f32>>(
        to_pad: &mut ArrayBase<S, Ix1>,
        input: &ArrayBase<T, Ix1>,
        padding: &[usize],
    ) {
        let (len, pad) = (input.len(), padding[0]);
        for (i, out) in to_pad.indexed_iter_mut() {
            *out = input[reflect_index(i, pad, len)];
        }
    }
}

impl ReflPad for Ix2 {
    /// Returns a new 2-D array holding `input` reflection-padded by `padding[0]` rows and
    /// `padding[1]` columns on each side.
    fn reflection_pad<S: DataMut<Elem = f32>>(
        input: &ArrayBase<S, Ix2>,
        padding: &[usize],
    ) -> Array<f32, Ix2> {
        let (rows, cols) = input.dim();
        let mut out = Array::<f32, _>::zeros((rows + 2 * padding[0], cols + 2 * padding[1]));
        Self::reflection_pad_inplace(&mut out, input, padding);
        out
    }

    /// Overwrites every element of `to_pad` with `input` reflection-padded by `padding`.
    ///
    /// Uses logical indexing rather than raw slices, so both arrays may have any memory
    /// layout.
    fn reflection_pad_inplace<S: DataMut<Elem = f32>, T: Data<Elem = f32>>(
        to_pad: &mut ArrayBase<S, Ix2>,
        input: &ArrayBase<T, Ix2>,
        padding: &[usize],
    ) {
        let (rows, cols) = input.dim();
        let (pad_rows, pad_cols) = (padding[0], padding[1]);
        for ((i, j), out) in to_pad.indexed_iter_mut() {
            *out = input[[
                reflect_index(i, pad_rows, rows),
                reflect_index(j, pad_cols, cols),
            ]];
        }
    }
}

impl ReflPad for Ix3 {
    /// Returns a new 3-D array holding `input` reflection-padded by `padding[k]` on both
    /// sides of every axis `k`.
    fn reflection_pad<S: DataMut<Elem = f32>>(
        input: &ArrayBase<S, Ix3>,
        padding: &[usize],
    ) -> Array<f32, Ix3> {
        let (d0, d1, d2) = input.dim();
        let mut out = Array::<f32, _>::zeros((
            d0 + 2 * padding[0],
            d1 + 2 * padding[1],
            d2 + 2 * padding[2],
        ));
        Self::reflection_pad_inplace(&mut out, input, padding);
        out
    }

    /// Overwrites every element of `to_pad` with `input` reflection-padded by `padding`.
    ///
    /// Uses logical indexing rather than raw slices, so both arrays may have any memory
    /// layout.
    fn reflection_pad_inplace<S: DataMut<Elem = f32>, T: Data<Elem = f32>>(
        to_pad: &mut ArrayBase<S, Self>,
        input: &ArrayBase<T, Self>,
        padding: &[usize],
    ) {
        let (d0, d1, d2) = input.dim();
        let (p0, p1, p2) = (padding[0], padding[1], padding[2]);
        for ((k, i, j), out) in to_pad.indexed_iter_mut() {
            *out = input[[
                reflect_index(k, p0, d0),
                reflect_index(i, p1, d1),
                reflect_index(j, p2, d2),
            ]];
        }
    }
}
/// Maps position `i` of a replication-padded axis back to an index into the source axis.
///
/// The source axis has `len` elements, and `pad` extra positions are added on each side.
/// Border positions repeat the first or last element, so for `len = 3, pad = 1` the padded
/// positions `0..5` map to `[0, 0, 1, 2, 2]`.
///
/// Requires `len >= 1` whenever `pad > 0`; an empty source axis has no edge to repeat.
fn replicate_index(i: usize, pad: usize, len: usize) -> usize {
    if i < pad {
        // Left border: repeat the first element.
        0
    } else if i < len + pad {
        // Interior: plain shift by the left padding.
        i - pad
    } else {
        // Right border: repeat the last element.
        len - 1
    }
}

impl ReplPad for Ix1 {
    /// Returns a new 1-D array holding `input` replication-padded by `padding[0]` at both
    /// ends.
    fn replication_pad<S: Data<Elem = f32>>(
        input: &ArrayBase<S, Ix1>,
        padding: &[usize],
    ) -> Array<f32, Ix1> {
        let mut out = Array::<f32, _>::zeros(input.len() + 2 * padding[0]);
        Self::replication_pad_inplace(&mut out, input, padding);
        out
    }

    /// Overwrites every element of `to_pad` with `input` replication-padded by `padding[0]`.
    ///
    /// Uses logical indexing rather than raw slices, so both arrays may have any memory
    /// layout (strided, transposed, or Fortran-order views are accepted).
    fn replication_pad_inplace<S: DataMut<Elem = f32>, T: Data<Elem = f32>>(
        to_pad: &mut ArrayBase<S, Self>,
        input: &ArrayBase<T, Self>,
        padding: &[usize],
    ) {
        let (len, pad) = (input.len(), padding[0]);
        for (i, out) in to_pad.indexed_iter_mut() {
            *out = input[replicate_index(i, pad, len)];
        }
    }
}

impl ReplPad for Ix2 {
    /// Returns a new 2-D array holding `input` replication-padded by `padding[0]` rows and
    /// `padding[1]` columns on each side.
    fn replication_pad<S: DataMut<Elem = f32>>(
        input: &ArrayBase<S, Ix2>,
        padding: &[usize],
    ) -> Array<f32, Ix2> {
        let (rows, cols) = input.dim();
        let mut out = Array::<f32, _>::zeros((rows + 2 * padding[0], cols + 2 * padding[1]));
        Self::replication_pad_inplace(&mut out, input, padding);
        out
    }

    /// Overwrites every element of `to_pad` with `input` replication-padded by `padding`.
    ///
    /// Uses logical indexing rather than raw slices, so both arrays may have any memory
    /// layout.
    fn replication_pad_inplace<S: DataMut<Elem = f32>, T: Data<Elem = f32>>(
        to_pad: &mut ArrayBase<S, Self>,
        input: &ArrayBase<T, Self>,
        padding: &[usize],
    ) {
        let (rows, cols) = input.dim();
        let (pad_rows, pad_cols) = (padding[0], padding[1]);
        for ((i, j), out) in to_pad.indexed_iter_mut() {
            *out = input[[
                replicate_index(i, pad_rows, rows),
                replicate_index(j, pad_cols, cols),
            ]];
        }
    }
}

impl ReplPad for Ix3 {
    /// Returns a new 3-D array holding `input` replication-padded by `padding[k]` on both
    /// sides of every axis `k`.
    fn replication_pad<S: DataMut<Elem = f32>>(
        input: &ArrayBase<S, Ix3>,
        padding: &[usize],
    ) -> Array<f32, Ix3> {
        let (d0, d1, d2) = input.dim();
        let mut out = Array::<f32, _>::zeros((
            d0 + 2 * padding[0],
            d1 + 2 * padding[1],
            d2 + 2 * padding[2],
        ));
        Self::replication_pad_inplace(&mut out, input, padding);
        out
    }

    /// Overwrites every element of `to_pad` with `input` replication-padded by `padding`.
    ///
    /// Uses logical indexing rather than raw slices, so both arrays may have any memory
    /// layout.
    fn replication_pad_inplace<S: DataMut<Elem = f32>, T: Data<Elem = f32>>(
        to_pad: &mut ArrayBase<S, Self>,
        input: &ArrayBase<T, Self>,
        padding: &[usize],
    ) {
        let (d0, d1, d2) = input.dim();
        let (p0, p1, p2) = (padding[0], padding[1], padding[2]);
        for ((k, i, j), out) in to_pad.indexed_iter_mut() {
            *out = input[[
                replicate_index(k, p0, d0),
                replicate_index(i, p1, d1),
                replicate_index(j, p2, d2),
            ]];
        }
    }
}
// The `test` submodule is declared here but defined in its own file; `cfg(test)` limits it
// to test builds.
#[cfg(test)]
mod test;