use alloc::vec::Vec;
use core::ops::Range;
use crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode};
use super::Numeric;
/// Conversion of a padding specification into one `(before, after)` pair per
/// tensor dimension.
///
/// `D` is the rank of the tensor being padded. The returned array has exactly
/// `D` entries; entry `i` holds the amounts added before and after dimension `i`.
pub trait IntoPadding<const D: usize> {
    /// Produces the per-dimension `(before, after)` padding amounts.
    fn into_padding(self) -> [(usize, usize); D];
}
/// Fixed-size arrays of `(before, after)` pairs.
///
/// The `N` pairs are aligned to the *trailing* `N` dimensions. Any leading
/// dimensions not covered by the array receive `(0, 0)`.
impl<const D: usize, const N: usize> IntoPadding<D> for [(usize, usize); N] {
    /// # Panics
    ///
    /// Panics if the array holds more pairs than the tensor has dimensions.
    fn into_padding(self) -> [(usize, usize); D] {
        assert!(
            N <= D,
            "Padding has {} pairs but tensor only has {} dimensions",
            N,
            D
        );
        // Number of leading dimensions that are left unpadded.
        let skip = D - N;
        core::array::from_fn(|axis| if axis < skip { (0, 0) } else { self[axis - skip] })
    }
}
/// A `(left, right, top, bottom)` tuple that pads the last two dimensions.
///
/// The last dimension receives `(left, right)`. The second-to-last dimension
/// receives `(top, bottom)`. Every other dimension receives `(0, 0)`.
impl<const D: usize> IntoPadding<D> for (usize, usize, usize, usize) {
    /// # Panics
    ///
    /// Panics if the tensor has fewer than 2 dimensions.
    fn into_padding(self) -> [(usize, usize); D] {
        // Without this guard, `D - 2` below underflows for rank-0/1 tensors and
        // panics with an unhelpful arithmetic-overflow or index-out-of-bounds message.
        assert!(
            D >= 2,
            "Padding tuple (left, right, top, bottom) requires at least 2 dimensions, but tensor has {}",
            D
        );
        let (left, right, top, bottom) = self;
        let mut result = [(0usize, 0usize); D];
        result[D - 2] = (top, bottom);
        result[D - 1] = (left, right);
        result
    }
}
/// Borrowed slices of `(before, after)` pairs.
///
/// The pairs are aligned to the *trailing* `self.len()` dimensions. Any leading
/// dimensions not covered by the slice receive `(0, 0)`.
impl<const D: usize> IntoPadding<D> for &[(usize, usize)] {
    /// # Panics
    ///
    /// Panics if the slice holds more pairs than the tensor has dimensions.
    fn into_padding(self) -> [(usize, usize); D] {
        let count = self.len();
        assert!(
            count <= D,
            "Padding has {} pairs but tensor only has {} dimensions",
            count,
            D
        );
        let mut result = [(0usize, 0usize); D];
        // The tail of `result` is exactly `count` long, so the lengths always match.
        result[D - count..].copy_from_slice(self);
        result
    }
}
/// Owned vectors of `(before, after)` pairs.
///
/// The pairs are aligned to the *trailing* `self.len()` dimensions. Any leading
/// dimensions not covered by the vector receive `(0, 0)`.
impl<const D: usize> IntoPadding<D> for Vec<(usize, usize)> {
    /// # Panics
    ///
    /// Panics if the vector holds more pairs than the tensor has dimensions.
    fn into_padding(self) -> [(usize, usize); D] {
        let count = self.len();
        assert!(
            count <= D,
            "Padding has {} pairs but tensor only has {} dimensions",
            count,
            D
        );
        let mut result = [(0usize, 0usize); D];
        // Skip the unpadded leading dimensions, then fill the rest in order.
        for (slot, pair) in result.iter_mut().skip(D - count).zip(self) {
            *slot = pair;
        }
        result
    }
}
/// Builds per-dimension slice ranges that select everything except along one
/// dimension.
///
/// Dimension `target_dim` gets `start..start + len`. Every other dimension `i`
/// gets its full extent `0..dims[i]`. If `target_dim >= D`, every dimension
/// gets its full extent.
fn build_slice_ranges<const D: usize>(
    dims: [usize; D],
    target_dim: usize,
    start: usize,
    len: usize,
) -> [Range<usize>; D] {
    core::array::from_fn(|axis| {
        if axis != target_dim {
            return 0..dims[axis];
        }
        start..start + len
    })
}
impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    /// Returns this tensor padded along one or more dimensions.
    ///
    /// `padding` is converted into one `(before, after)` pair per dimension through
    /// [`IntoPadding`]. The array, slice and `Vec` conversions apply a shorter
    /// specification to the trailing dimensions. A `(left, right, top, bottom)`
    /// tuple pads the last two dimensions.
    ///
    /// `mode` selects how the new elements are filled:
    /// - `PadMode::Constant(value)` fills them with `value`.
    /// - `PadMode::Reflect` mirrors the neighbouring contents without repeating
    ///   the border element.
    /// - `PadMode::Edge` repeats the border element.
    ///
    /// # Panics
    ///
    /// - If the padding specification cannot be mapped onto the tensor rank, for
    ///   example when it has more pairs than the tensor has dimensions.
    /// - For `Reflect`, if a padding amount on a padded dimension is not strictly
    ///   smaller than that dimension's size.
    /// - For `Edge`, if a padded dimension has size zero.
    pub fn pad(self, padding: impl IntoPadding<D>, mode: impl Into<PadMode>) -> Self {
        let pairs: [(usize, usize); D] = padding.into_padding();
        let mode: PadMode = mode.into();
        match mode {
            PadMode::Reflect => pad_reflect(self, &pairs),
            PadMode::Edge => pad_edge(self, &pairs),
            PadMode::Constant(value) => pad_constant(self, &pairs, value),
        }
    }
}
/// Pads every dimension with a constant fill value.
///
/// This allocates a tensor of the padded shape filled with `value`. It then
/// writes `tensor` into the interior region, which starts at `before` and has
/// the original size along each dimension.
fn pad_constant<B, const D: usize, K, E>(
    tensor: Tensor<B, D, K>,
    padding: &[(usize, usize); D],
    value: E,
) -> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
    E: ElementConversion,
{
    let original_dims: [usize; D] = tensor.dims();
    let padded_dims: [usize; D] = core::array::from_fn(|axis| {
        let (before, after) = padding[axis];
        original_dims[axis] + before + after
    });
    // Interior window where the unpadded data is placed.
    let interior: [Range<usize>; D] = core::array::from_fn(|axis| {
        let start = padding[axis].0;
        start..start + original_dims[axis]
    });
    let canvas: Tensor<B, D, K> = Tensor::full(padded_dims, value, &tensor.device());
    canvas.slice_assign(interior, tensor)
}
/// Applies reflect padding, one dimension at a time.
///
/// Every dimension with non-zero padding is validated before any work is done.
/// The padding is then applied dimension by dimension in ascending order, each
/// step operating on the result of the previous one.
///
/// # Panics
///
/// Panics if, for a padded dimension, either amount is not strictly less than
/// that dimension's size.
fn pad_reflect<B, const D: usize, K>(
    tensor: Tensor<B, D, K>,
    padding: &[(usize, usize); D],
) -> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    let dims = tensor.dims();
    // Dimensions that actually request padding, as `(axis, (before, after))`.
    let requested = || {
        padding
            .iter()
            .copied()
            .enumerate()
            .filter(|&(_, (before, after))| before > 0 || after > 0)
    };

    for (axis, (before, after)) in requested() {
        assert!(
            before < dims[axis] && after < dims[axis],
            "Reflect padding ({}, {}) must be less than dimension {} size ({})",
            before,
            after,
            axis,
            dims[axis]
        );
    }

    requested().fold(tensor, |acc, (axis, (before, after))| {
        pad_reflect_dim(acc, axis, before, after)
    })
}
/// Reflect-pads `tensor` along a single dimension `dim`.
///
/// The border element is not repeated. For example, `[a, b, c, d]` padded by
/// `(2, 2)` becomes `[c, b, a, b, c, d, c, b]`. The caller is expected to ensure
/// both amounts are smaller than the dimension size; otherwise `narrow`
/// receives an out-of-range window.
fn pad_reflect_dim<B, const D: usize, K>(
    tensor: Tensor<B, D, K>,
    dim: usize,
    pad_before: usize,
    pad_after: usize,
) -> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    let source_dims = tensor.dims();
    let size = source_dims[dim];
    let mut target_dims = source_dims;
    target_dims[dim] = size + pad_before + pad_after;
    let flip_axes = [dim as isize];

    let canvas: Tensor<B, D, K> = Tensor::zeros(target_dims, &tensor.device());
    let interior = build_slice_ranges(target_dims, dim, pad_before, size);
    let mut padded = canvas.slice_assign(interior, tensor.clone());

    if pad_before > 0 {
        // Indices 1..=pad_before, reversed: index 0 (the border) is skipped.
        let mirrored = tensor.clone().narrow(dim, 1, pad_before).flip(flip_axes);
        let region = build_slice_ranges(target_dims, dim, 0, pad_before);
        padded = padded.slice_assign(region, mirrored);
    }

    if pad_after > 0 {
        // Indices size-1-pad_after ..= size-2, reversed: the last index is skipped.
        let window_start = size - pad_after - 1;
        let mirrored = tensor.narrow(dim, window_start, pad_after).flip(flip_axes);
        let region = build_slice_ranges(target_dims, dim, pad_before + size, pad_after);
        padded = padded.slice_assign(region, mirrored);
    }

    padded
}
/// Applies edge (replicate) padding, one dimension at a time.
///
/// Every dimension with non-zero padding is validated before any work is done.
/// The padding is then applied dimension by dimension in ascending order, each
/// step operating on the result of the previous one.
///
/// # Panics
///
/// Panics if a dimension that requests padding has size zero, because there is
/// no border element to repeat.
fn pad_edge<B, const D: usize, K>(
    tensor: Tensor<B, D, K>,
    padding: &[(usize, usize); D],
) -> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    let dims = tensor.dims();
    // Dimensions that actually request padding, as `(axis, (before, after))`.
    let requested = || {
        padding
            .iter()
            .copied()
            .enumerate()
            .filter(|&(_, (before, after))| before > 0 || after > 0)
    };

    for (axis, _) in requested() {
        assert!(
            dims[axis] > 0,
            "Cannot apply edge padding to zero-sized dimension {}",
            axis
        );
    }

    requested().fold(tensor, |acc, (axis, (before, after))| {
        pad_edge_dim(acc, axis, before, after)
    })
}
/// Edge-pads `tensor` along a single dimension `dim`.
///
/// The first slice along `dim` is repeated `pad_before` times in front. The
/// last slice is repeated `pad_after` times behind. For example, `[a, b, c]`
/// padded by `(2, 1)` becomes `[a, a, a, b, c, c]`. The caller is expected to
/// ensure the dimension is non-empty.
fn pad_edge_dim<B, const D: usize, K>(
    tensor: Tensor<B, D, K>,
    dim: usize,
    pad_before: usize,
    pad_after: usize,
) -> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    let source_dims = tensor.dims();
    let size = source_dims[dim];
    let mut target_dims = source_dims;
    target_dims[dim] = size + pad_before + pad_after;

    let canvas: Tensor<B, D, K> = Tensor::zeros(target_dims, &tensor.device());
    let interior = build_slice_ranges(target_dims, dim, pad_before, size);
    let mut padded = canvas.slice_assign(interior, tensor.clone());

    if pad_before > 0 {
        let leading = tensor.clone().narrow(dim, 0, 1).repeat_dim(dim, pad_before);
        let region = build_slice_ranges(target_dims, dim, 0, pad_before);
        padded = padded.slice_assign(region, leading);
    }

    if pad_after > 0 {
        let trailing = tensor.narrow(dim, size - 1, 1).repeat_dim(dim, pad_after);
        let region = build_slice_ranges(target_dims, dim, pad_before + size, pad_after);
        padded = padded.slice_assign(region, trailing);
    }

    padded
}