use alloc::vec::Vec;
use core::ops::Range;
use crate::{Element, ElementConversion, Tensor, backend::Backend, ops::PadMode};
use super::Numeric;
/// Builds a `D`-dimensional slice selector over a tensor of shape `dims`.
///
/// Every axis spans its full extent (`0..dims[axis]`), except `target_dim`,
/// which spans `start..start + len`. If `target_dim` is not a valid axis, every
/// range covers its full extent.
fn build_slice_ranges<const D: usize>(
    dims: [usize; D],
    target_dim: usize,
    start: usize,
    len: usize,
) -> [Range<usize>; D] {
    core::array::from_fn(|axis| {
        if axis != target_dim {
            return 0..dims[axis];
        }
        start..start + len
    })
}
impl<B, const D: usize, K> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    /// Pads the last two dimensions of the tensor.
    ///
    /// # Arguments
    ///
    /// * `padding` - `(left, right, top, bottom)`. `left`/`right` extend the last
    ///   dimension (`D - 1`), `top`/`bottom` extend the second-to-last (`D - 2`).
    /// * `mode` - how the new elements are filled:
    ///   * `PadMode::Constant(value)` fills them with `value`.
    ///   * `PadMode::Reflect` mirrors the tensor without repeating the edge element.
    ///   * `PadMode::Edge` replicates the border element.
    ///
    /// # Returns
    ///
    /// A new tensor with the padded shape.
    ///
    /// # Panics
    ///
    /// Panics under the conditions checked by `pad_constant`, `pad_reflect` and
    /// `pad_edge`. For example, reflect padding must be smaller than the padded
    /// dimension, and edge padding needs a non-empty dimension.
    pub fn pad(self, padding: (usize, usize, usize, usize), mode: impl Into<PadMode>) -> Self {
        let resolved = mode.into();
        match resolved {
            PadMode::Reflect => pad_reflect(self, padding),
            PadMode::Edge => pad_edge(self, padding),
            PadMode::Constant(fill) => pad_constant(self, padding, fill),
        }
    }
}
pub fn pad_constant<B, const D: usize, K, E>(
tensor: Tensor<B, D, K>,
padding: (usize, usize, usize, usize),
value: E,
) -> Tensor<B, D, K>
where
B: Backend,
K: Numeric<B>,
K::Elem: Element,
E: ElementConversion,
{
let (left, right, top, bottom) = padding;
let mut padded_dims: [usize; D] = tensor.dims();
padded_dims[D - 2] += top + bottom;
padded_dims[D - 1] += left + right;
let ranges: [core::ops::Range<usize>; D] = padded_dims
.iter()
.enumerate()
.map(|(i, &dim)| {
if i == D - 2 {
top..dim - bottom
} else if i == D - 1 {
left..dim - right
} else {
0..dim
}
})
.collect::<Vec<core::ops::Range<usize>>>()
.try_into()
.unwrap();
let padded_tensor = Tensor::full(padded_dims, value, &tensor.device());
padded_tensor.slice_assign(ranges, tensor)
}
/// Pads the last two dimensions of `tensor` by reflection.
///
/// Reflection mirrors the data around the border element without repeating it.
/// For example, `[a, b, c, d]` padded by 2 on both sides becomes
/// `[c, b, a, b, c, d, c, b]`.
///
/// # Arguments
///
/// * `tensor` - the tensor to pad.
/// * `padding` - `(left, right, top, bottom)`. `left`/`right` apply to axis `D - 1`,
///   `top`/`bottom` to axis `D - 2`.
///
/// # Returns
///
/// The padded tensor. Axes with zero padding are left untouched.
///
/// # Panics
///
/// If an axis is padded at all, each of its two pad amounts must be strictly
/// smaller than that axis' size. Axes with zero padding are not checked. So a
/// zero pad on an empty axis, or width-only padding of a rank-1 tensor, is
/// accepted. This matches the checks in `pad_edge`.
pub fn pad_reflect<B, const D: usize, K>(
    tensor: Tensor<B, D, K>,
    padding: (usize, usize, usize, usize),
) -> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    let (left, right, top, bottom) = padding;
    let dims = tensor.dims();
    let pad_height = top > 0 || bottom > 0;
    let pad_width = left > 0 || right > 0;

    // Validate both axes before doing any work. Only padded axes are checked:
    // a zero pad is a no-op and must not be rejected because the axis is empty.
    // Skipping the check also avoids indexing `D - 2` when it is not needed.
    if pad_height {
        assert!(
            top < dims[D - 2] && bottom < dims[D - 2],
            "Reflect padding on height ({}, {}) must be less than height dimension ({})",
            top,
            bottom,
            dims[D - 2]
        );
    }
    if pad_width {
        assert!(
            left < dims[D - 1] && right < dims[D - 1],
            "Reflect padding on width ({}, {}) must be less than width dimension ({})",
            left,
            right,
            dims[D - 1]
        );
    }

    let mut result = tensor;
    if pad_height {
        result = pad_reflect_dim(result, D - 2, top, bottom);
    }
    if pad_width {
        result = pad_reflect_dim(result, D - 1, left, right);
    }
    result
}
/// Reflect-pads a single axis `dim` of `tensor`.
///
/// The output is a zero-filled tensor of the enlarged shape. The source is
/// copied into `pad_before..pad_before + size`, and the two border regions are
/// filled with flipped slices of the source that exclude the edge element
/// itself.
///
/// The caller must ensure `pad_before` and `pad_after` are smaller than the size
/// of `dim`. Otherwise the `narrow` ranges below fall outside the tensor (and the
/// `size - pad_after - 1` subtraction can underflow).
fn pad_reflect_dim<B, const D: usize, K>(
    tensor: Tensor<B, D, K>,
    dim: usize,
    pad_before: usize,
    pad_after: usize,
) -> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    let in_shape = tensor.dims();
    let extent = in_shape[dim];
    let mut out_shape = in_shape;
    out_shape[dim] += pad_before + pad_after;

    // Copy the untouched source into the middle of a zero canvas.
    let body = build_slice_ranges(out_shape, dim, pad_before, extent);
    let mut canvas =
        Tensor::zeros(out_shape, &tensor.device()).slice_assign(body, tensor.clone());

    if pad_before > 0 {
        // Start at index 1 so the first element is not duplicated in the mirror.
        let mirrored = tensor
            .clone()
            .narrow(dim, 1, pad_before)
            .flip([dim as isize]);
        let target = build_slice_ranges(out_shape, dim, 0, pad_before);
        canvas = canvas.slice_assign(target, mirrored);
    }

    if pad_after > 0 {
        // Take the `pad_after` elements immediately before the last one.
        let tail_start = extent - pad_after - 1;
        let mirrored = tensor
            .narrow(dim, tail_start, pad_after)
            .flip([dim as isize]);
        let target = build_slice_ranges(out_shape, dim, pad_before + extent, pad_after);
        canvas = canvas.slice_assign(target, mirrored);
    }

    canvas
}
/// Pads the last two dimensions of `tensor` by replicating its border elements.
///
/// # Arguments
///
/// * `tensor` - the tensor to pad.
/// * `padding` - `(left, right, top, bottom)`. `left`/`right` apply to axis `D - 1`,
///   `top`/`bottom` to axis `D - 2`.
///
/// # Returns
///
/// The padded tensor. Axes with zero padding are left untouched.
///
/// # Panics
///
/// Panics if a padded axis has size zero, because there is no border element to
/// replicate.
pub fn pad_edge<B, const D: usize, K>(
    tensor: Tensor<B, D, K>,
    padding: (usize, usize, usize, usize),
) -> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    let (left, right, top, bottom) = padding;
    let shape = tensor.dims();
    let pads_height = top > 0 || bottom > 0;
    let pads_width = left > 0 || right > 0;

    // Short-circuit: the dimension is only inspected when that axis is padded.
    assert!(
        !pads_height || shape[D - 2] > 0,
        "Cannot apply edge padding to zero-sized height dimension"
    );
    assert!(
        !pads_width || shape[D - 1] > 0,
        "Cannot apply edge padding to zero-sized width dimension"
    );

    let after_height = if pads_height {
        pad_edge_dim(tensor, D - 2, top, bottom)
    } else {
        tensor
    };

    if pads_width {
        pad_edge_dim(after_height, D - 1, left, right)
    } else {
        after_height
    }
}
/// Edge-pads a single axis `dim` of `tensor`.
///
/// The source is copied into `pad_before..pad_before + size` of a zero-filled
/// tensor of the enlarged shape. The leading region is then filled by repeating
/// the first slice along `dim`, and the trailing region by repeating the last.
///
/// The caller must ensure the size of `dim` is non-zero when any padding is
/// requested. Otherwise `size - 1` below underflows.
fn pad_edge_dim<B, const D: usize, K>(
    tensor: Tensor<B, D, K>,
    dim: usize,
    pad_before: usize,
    pad_after: usize,
) -> Tensor<B, D, K>
where
    B: Backend,
    K: Numeric<B>,
    K::Elem: Element,
{
    let in_shape = tensor.dims();
    let extent = in_shape[dim];
    let mut out_shape = in_shape;
    out_shape[dim] += pad_before + pad_after;

    let body = build_slice_ranges(out_shape, dim, pad_before, extent);
    let mut canvas =
        Tensor::zeros(out_shape, &tensor.device()).slice_assign(body, tensor.clone());

    if pad_before > 0 {
        let leading = tensor.clone().narrow(dim, 0, 1).repeat_dim(dim, pad_before);
        let target = build_slice_ranges(out_shape, dim, 0, pad_before);
        canvas = canvas.slice_assign(target, leading);
    }

    if pad_after > 0 {
        let trailing = tensor.narrow(dim, extent - 1, 1).repeat_dim(dim, pad_after);
        let target = build_slice_ranges(out_shape, dim, pad_before + extent, pad_after);
        canvas = canvas.slice_assign(target, trailing);
    }

    canvas
}