pub use self::utils::*;
use num::Zero;
use strum::{AsRefStr, Display, EnumCount, EnumIs, EnumIter, VariantNames};
/// Something that can produce a padded version of itself.
///
/// The trait parameter `T` is left entirely to implementors; the trait itself places no
/// constraint on it.
pub trait PadItem<T> {
    /// The type returned by [`PadItem::pad`].
    type Output;

    /// Produces a padded value from `self`.
    ///
    /// NOTE(review): `amount` is presumably the pad width — confirm against implementors.
    fn pad(&self, amount: usize) -> Self::Output;
}
/// Strategy applied to the border after the source data has been copied into the padded
/// array. Every [`PadMode`] maps to exactly one action via `PadMode::action`.
#[derive(
    AsRefStr,
    Clone,
    Copy,
    Debug,
    Default,
    Display,
    EnumCount,
    EnumIs,
    EnumIter,
    Eq,
    Hash,
    Ord,
    PartialEq,
    PartialOrd,
    VariantNames,
)]
#[repr(u8)]
// Externally tagged (serde's default): every variant is a unit variant, so each one
// (de)serializes as its snake_case name. `untagged` would serialize every variant as `null`
// and deserialize `null` back to the first variant, silently losing the value.
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(rename_all = "snake_case")
)]
#[strum(serialize_all = "snake_case")]
pub enum PadAction {
    /// Action for [`PadMode::Edge`] and the statistic-based modes
    /// (`Maximum`, `Mean`, `Median`, `Minimum`, `Mode`).
    Clipping,
    /// Not returned by `PadMode::action` for any mode.
    /// NOTE(review): confirm where this action is meant to be used.
    Lane,
    /// Action for [`PadMode::Reflect`] and [`PadMode::Symmetric`].
    Reflecting,
    /// Action for [`PadMode::Constant`]: nothing is written after the data copy. This is
    /// the default action.
    #[default]
    StopAfterCopy,
    /// Action for [`PadMode::Wrap`].
    Wrapping,
}
/// How the border of a padded array is filled.
///
/// Most variant names correspond to `numpy.pad` modes. NOTE(review): they presumably share
/// its semantics — confirm. Only [`PadMode::Constant`] is currently implemented by
/// `pad_to`; every other mode hits `unimplemented!()` there.
#[derive(Clone, Copy, Debug, Display, EnumCount, EnumIs, Eq, Hash, Ord, PartialEq, PartialOrd)]
// Externally tagged (serde's default): unit variants (de)serialize as their lowercase
// name and `Constant(v)` as `{"constant": v}`. `untagged` would map every unit variant to
// `null` and deserialize `null` back to a single variant, losing the mode.
#[cfg_attr(
    feature = "serde",
    derive(serde::Deserialize, serde::Serialize),
    serde(rename_all = "lowercase")
)]
pub enum PadMode<T> {
    /// Border filled with the given value (used as the initial fill value by `pad`).
    Constant(T),
    /// Maps to [`PadAction::Clipping`].
    Edge,
    /// Maps to [`PadAction::Clipping`].
    Maximum,
    /// Maps to [`PadAction::Clipping`].
    Mean,
    /// Maps to [`PadAction::Clipping`].
    Median,
    /// Maps to [`PadAction::Clipping`].
    Minimum,
    /// Maps to [`PadAction::Clipping`].
    Mode,
    /// Maps to [`PadAction::Reflecting`].
    Reflect,
    /// Maps to [`PadAction::Reflecting`].
    Symmetric,
    /// Maps to [`PadAction::Wrapping`].
    Wrap,
}
impl<T> From<T> for PadMode<T> {
fn from(value: T) -> Self {
PadMode::Constant(value)
}
}
impl<T> PadMode<T> {
    /// Returns the [`PadAction`] that implements this mode.
    ///
    /// `Constant` needs nothing beyond the data copy (`StopAfterCopy`). `Edge` and the
    /// statistic-based modes map to `Clipping`, `Reflect` and `Symmetric` to `Reflecting`,
    /// and `Wrap` to `Wrapping`.
    pub(crate) fn action(&self) -> PadAction {
        match self {
            Self::Constant(_) => PadAction::StopAfterCopy,
            Self::Edge
            | Self::Maximum
            | Self::Mean
            | Self::Median
            | Self::Minimum
            | Self::Mode => PadAction::Clipping,
            Self::Reflect | Self::Symmetric => PadAction::Reflecting,
            Self::Wrap => PadAction::Wrapping,
        }
    }

    /// Returns the initial fill value for a padded buffer.
    ///
    /// This is the wrapped value for `Constant(v)` and `T::zero()` for every other mode.
    pub fn init(&self) -> T
    where
        T: Copy + Zero,
    {
        if let Self::Constant(value) = self {
            *value
        } else {
            T::zero()
        }
    }
}
/// A padding specification: the [`PadMode`] to apply and a pad amount.
///
/// NOTE(review): `pad` is a single `usize`, presumably the same width on every side of
/// every axis — confirm against the code that consumes `Padding`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct Padding<T> {
    /// How the padded border is filled.
    pub mode: PadMode<T>,
    /// The pad amount (see the note on the struct).
    pub pad: usize,
}
mod utils {
use super::{PadAction, PadMode};
use crate::traits::ArrayLike;
use ndarray::{Array, ArrayBase, AxisDescription, Data, DataOwned, Dimension, Slice};
use num::{FromPrimitive, Num};
#[cfg(no_std)]
use alloc::borrow::Cow;
#[cfg(feature = "std")]
use std::borrow::Cow;
fn read_pad(nb_dim: usize, pad: &[[usize; 2]]) -> Cow<[[usize; 2]]> {
if pad.len() == 1 && pad.len() < nb_dim {
Cow::from(vec![pad[0]; nb_dim])
} else if pad.len() == nb_dim {
Cow::from(pad)
} else {
panic!("Inconsistant number of dimensions and pad arrays");
}
}
pub fn pad<A, S, D>(data: &ArrayBase<S, D>, pad: &[[usize; 2]], mode: PadMode<A>) -> Array<A, D>
where
A: Copy + FromPrimitive + Num,
D: Dimension,
S: DataOwned<Elem = A>,
{
let pad = read_pad(data.ndim(), pad);
let mut new_dim = data.raw_dim();
for (ax, (&ax_len, pad)) in data.shape().iter().zip(pad.iter()).enumerate() {
new_dim[ax] = ax_len + pad[0] + pad[1];
}
let mut padded = data.array_like(new_dim, mode.init()).to_owned();
pad_to(data, &pad, mode, &mut padded);
padded
}
pub fn pad_to<A, S, D>(
data: &ArrayBase<S, D>,
pad: &[[usize; 2]],
mode: PadMode<A>,
output: &mut Array<A, D>,
) where
A: Copy + FromPrimitive + Num,
D: Dimension,
S: Data<Elem = A>,
{
let pad = read_pad(data.ndim(), pad);
output
.slice_each_axis_mut(|ad| {
let AxisDescription { axis, len, .. } = ad;
let pad = pad[axis.index()];
Slice::from(pad[0]..len - pad[1])
})
.assign(data);
match mode.action() {
PadAction::StopAfterCopy => { }
_ => unimplemented!(),
}
}
}