use ndarray::{ArrayBase, ArrayD, ArrayViewMutD, AsArray, Axis, Dimension, Slice, ViewRepr};
use crate::error::ImgalError;
use crate::traits::numeric::AsNumeric;
/// Pads an n-dimensional array with a constant `value`.
///
/// # Arguments
///
/// * `data` - The input array (anything convertible into an ndarray view).
/// * `value` - The value written into every padded element.
/// * `pad_config` - The number of elements to pad on each axis. Must have one
///   entry per axis of `data`.
/// * `direction` - Which side(s) of every axis receive padding:
///   `0` pads after the data (end of each axis), `1` pads before the data
///   (start of each axis), `2` pads both sides by the full amount.
///   Defaults to `2` when `None`.
///
/// # Returns
///
/// A new owned array. With `direction` `0` or `1`, axis `i` grows by
/// `pad_config[i]`; with `direction` `2`, it grows by `2 * pad_config[i]`.
/// If every pad amount is zero, an owned copy of `data` is returned.
///
/// # Errors
///
/// * `ImgalError::MismatchedArrayLengths` if `pad_config.len()` differs from
///   the number of axes of `data`.
/// * `ImgalError::InvalidParameterValueGreater` if `direction` is greater
///   than `2`.
pub fn constant_pad<'a, T, A, D>(
    data: A,
    value: T,
    pad_config: &[usize],
    direction: Option<u8>,
) -> Result<ArrayD<T>, ImgalError>
where
    A: AsArray<'a, T, D>,
    D: Dimension,
    T: 'a + AsNumeric,
{
    let view: ArrayBase<ViewRepr<&'a T>, D> = data.into();
    let src_shape = view.shape().to_vec();
    let sl = src_shape.len();
    if sl != pad_config.len() {
        return Err(ImgalError::MismatchedArrayLengths {
            a_arr_name: "shape",
            a_arr_len: sl,
            b_arr_name: "pad_config",
            b_arr_len: pad_config.len(),
        });
    }
    // Validate `direction` before the no-op shortcut so an invalid value is
    // always reported, not only when some padding is actually requested.
    let direction = direction.unwrap_or(2);
    if direction > 2 {
        return Err(ImgalError::InvalidParameterValueGreater {
            param_name: "direction",
            value: 2,
        });
    }
    if pad_config.iter().all(|&v| v == 0) {
        return Ok(view.into_dyn().to_owned());
    }
    // Only direction 2 pads both ends of each axis.
    let pad_shape = create_pad_shape(&src_shape, pad_config, direction == 2);
    let mut pad_arr = ArrayD::from_elem(pad_shape, value);
    // Narrow a view down to the region where the source belongs, then copy
    // the source into it; everything outside that region keeps `value`.
    let mut pad_view = pad_arr.view_mut();
    slice_pad_view(&mut pad_view, &src_shape, pad_config, direction);
    pad_view.assign(&view);
    Ok(pad_arr)
}
/// Pads an n-dimensional array by reflecting its values about the edges.
///
/// The edge element itself is not repeated. For example, padding
/// `[a, b, c, d]` by 2 on the start gives `[c, b, a, b, c, d]`.
///
/// # Arguments
///
/// * `data` - The input array (anything convertible into an ndarray view).
/// * `pad_config` - The number of elements to pad on each axis. Must have one
///   entry per axis of `data`. Each non-zero entry must be strictly smaller
///   than that axis's length.
/// * `direction` - Which side(s) of every axis receive padding:
///   `0` pads after the data, `1` pads before the data, `2` pads both sides
///   by the full amount. Defaults to `2` when `None`.
///
/// # Returns
///
/// A new owned array. Axes are reflected one after another across the whole
/// intermediate array, including regions already padded for other axes. As a
/// result, corner regions are filled by successive reflections. If every pad
/// amount is zero, an owned copy of `data` is returned.
///
/// # Errors
///
/// * `ImgalError::MismatchedArrayLengths` if `pad_config.len()` differs from
///   the number of axes of `data`.
/// * `ImgalError::InvalidAxisValueGreaterEqual` if a non-zero pad amount is
///   greater than or equal to the length of its axis.
/// * `ImgalError::InvalidParameterValueGreater` if `direction` is greater
///   than `2`.
pub fn reflect_pad<'a, T, A, D>(
    data: A,
    pad_config: &[usize],
    direction: Option<u8>,
) -> Result<ArrayD<T>, ImgalError>
where
    A: AsArray<'a, T, D>,
    D: Dimension,
    T: 'a + AsNumeric,
{
    let view: ArrayBase<ViewRepr<&'a T>, D> = data.into();
    let src_shape = view.shape().to_vec();
    let sl = src_shape.len();
    if sl != pad_config.len() {
        return Err(ImgalError::MismatchedArrayLengths {
            a_arr_name: "shape",
            a_arr_len: sl,
            b_arr_name: "pad_config",
            b_arr_len: pad_config.len(),
        });
    }
    // Reflecting `p` elements needs `p` source elements beyond the edge
    // element, so a non-zero pad must be strictly smaller than the axis.
    // A zero pad is a no-op and is valid even on an empty axis.
    if let Some((i, (_, &s))) = pad_config
        .iter()
        .zip(&src_shape)
        .enumerate()
        .find(|&(_, (&p, &s))| p != 0 && p >= s)
    {
        return Err(ImgalError::InvalidAxisValueGreaterEqual {
            arr_name: "pad_config",
            axis_idx: i,
            value: s,
        });
    }
    // Validate `direction` before the no-op shortcut so an invalid value is
    // always reported.
    let direction = direction.unwrap_or(2);
    if direction > 2 {
        return Err(ImgalError::InvalidParameterValueGreater {
            param_name: "direction",
            value: 2,
        });
    }
    if pad_config.iter().all(|&v| v == 0) {
        return Ok(view.into_dyn().to_owned());
    }
    // Start from a zero-padded copy with the source already in place, then
    // overwrite the padded regions axis by axis with reflected values.
    let mut pad_arr = zero_pad(view, pad_config, Some(direction))?;
    for (i, (&p, &s)) in pad_config.iter().zip(&src_shape).enumerate() {
        if p == 0 {
            continue;
        }
        let axis = Axis(i);
        let pad_view = pad_arr.view_mut();
        match direction {
            // Padding follows the data: mirror the tail, excluding the last element.
            0 => {
                let (src_data, mut end_pad) = pad_view.split_at(axis, s);
                let mut end_reflect =
                    src_data.slice_axis(axis, Slice::from((s - p - 1)..(s - 1)));
                end_reflect.invert_axis(axis);
                end_pad.assign(&end_reflect);
            }
            // Padding precedes the data: mirror the head, excluding the first element.
            1 => {
                let (mut start_pad, src_data) = pad_view.split_at(axis, p);
                let mut start_reflect = src_data.slice_axis(axis, Slice::from(1..p + 1));
                start_reflect.invert_axis(axis);
                start_pad.assign(&start_reflect);
            }
            // Padding on both sides (direction 2; larger values were rejected above).
            _ => {
                let (mut start_pad, chunk) = pad_view.split_at(axis, p);
                let (src_data, mut end_pad) = chunk.split_at(axis, s);
                let mut start_reflect = src_data.slice_axis(axis, Slice::from(1..p + 1));
                start_reflect.invert_axis(axis);
                start_pad.assign(&start_reflect);
                let mut end_reflect =
                    src_data.slice_axis(axis, Slice::from((s - p - 1)..(s - 1)));
                end_reflect.invert_axis(axis);
                end_pad.assign(&end_reflect);
            }
        }
    }
    Ok(pad_arr)
}
/// Pads an n-dimensional array with `T::default()`.
///
/// `T::default()` is presumably the numeric zero for `AsNumeric` types
/// (TODO confirm against the trait definition).
///
/// # Arguments
///
/// * `data` - The input array (anything convertible into an ndarray view).
/// * `pad_config` - The number of elements to pad on each axis. Must have one
///   entry per axis of `data`.
/// * `direction` - Which side(s) of every axis receive padding:
///   `0` pads after the data (end of each axis), `1` pads before the data
///   (start of each axis), `2` pads both sides by the full amount.
///   Defaults to `2` when `None`.
///
/// # Returns
///
/// A new owned array. With `direction` `0` or `1`, axis `i` grows by
/// `pad_config[i]`; with `direction` `2`, it grows by `2 * pad_config[i]`.
/// If every pad amount is zero, an owned copy of `data` is returned.
///
/// # Errors
///
/// * `ImgalError::MismatchedArrayLengths` if `pad_config.len()` differs from
///   the number of axes of `data`.
/// * `ImgalError::InvalidParameterValueGreater` if `direction` is greater
///   than `2`.
pub fn zero_pad<'a, T, A, D>(
    data: A,
    pad_config: &[usize],
    direction: Option<u8>,
) -> Result<ArrayD<T>, ImgalError>
where
    A: AsArray<'a, T, D>,
    D: Dimension,
    T: 'a + AsNumeric,
{
    let view: ArrayBase<ViewRepr<&'a T>, D> = data.into();
    let src_shape = view.shape().to_vec();
    let sl = src_shape.len();
    if sl != pad_config.len() {
        return Err(ImgalError::MismatchedArrayLengths {
            a_arr_name: "shape",
            a_arr_len: sl,
            b_arr_name: "pad_config",
            b_arr_len: pad_config.len(),
        });
    }
    // Validate `direction` before the no-op shortcut so an invalid value is
    // always reported, not only when some padding is actually requested.
    let direction = direction.unwrap_or(2);
    if direction > 2 {
        return Err(ImgalError::InvalidParameterValueGreater {
            param_name: "direction",
            value: 2,
        });
    }
    if pad_config.iter().all(|&v| v == 0) {
        return Ok(view.into_dyn().to_owned());
    }
    // Only direction 2 pads both ends of each axis.
    let pad_shape = create_pad_shape(&src_shape, pad_config, direction == 2);
    let mut pad_arr = ArrayD::<T>::default(pad_shape);
    // Narrow a view down to the region where the source belongs, then copy
    // the source into it; everything outside that region keeps the default.
    let mut pad_view = pad_arr.view_mut();
    slice_pad_view(&mut pad_view, &src_shape, pad_config, direction);
    pad_view.assign(&view);
    Ok(pad_arr)
}
/// Computes the shape that results from padding `shape` by `pad_config`.
///
/// Each axis grows by its pad amount once (one-sided padding) or twice when
/// `symmetric` is set (padding on both ends). Any axis of `shape` without a
/// matching entry in `pad_config` is reported as `0`. Callers are expected to
/// have already checked that both slices have the same length.
#[inline]
fn create_pad_shape(shape: &[usize], pad_config: &[usize], symmetric: bool) -> Vec<usize> {
    let sides = if symmetric { 2 } else { 1 };
    let mut out = Vec::with_capacity(shape.len());
    for (axis, &len) in shape.iter().enumerate() {
        let grown = match pad_config.get(axis) {
            Some(&pad) => len + sides * pad,
            None => 0,
        };
        out.push(grown);
    }
    out
}
#[inline]
fn slice_pad_view<T>(
view: &mut ArrayViewMutD<T>,
slice_shape: &[usize],
pad_config: &[usize],
direction: u8,
) where
T: AsNumeric,
{
pad_config
.iter()
.zip(slice_shape.iter())
.enumerate()
.filter(|(_, (p, _))| **p != 0)
.for_each(|(i, (&p, &s))| {
let ax_slice: Slice;
match direction {
0 => {
ax_slice = Slice {
start: 0 as isize,
end: Some(s as isize),
step: 1,
}
}
_ => {
ax_slice = Slice {
start: p as isize,
end: Some((p + s) as isize),
step: 1,
}
}
}
view.slice_axis_inplace(Axis(i), ax_slice);
});
}