use crate::DType;
use crate::signal::traits::nd_filters::BoundaryMode;
use numr::error::{Error, Result};
use numr::ops::{ScalarOps, ShapeOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Side of the data on which a block of padding is placed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PadSide {
    /// Padding that precedes the first sample along the axis.
    Before,
    /// Padding that follows the last sample along the axis.
    After,
}

/// Pads `input` along a single axis according to `mode`.
///
/// The result has the same shape as `input` except along `axis`, whose length becomes
/// `pad_before + len + pad_after` (`len` being the original length of that axis). Pad widths
/// larger than the axis are supported: the non-constant modes keep extending the data
/// periodically instead of emitting fewer samples than requested.
///
/// Extension of a 1-D input `a b c d` by two samples on each side:
///
/// - `Constant(v)`: `v v | a b c d | v v`
/// - `Nearest`:     `a a | a b c d | d d`
/// - `Reflect`:     `c b | a b c d | c b` (edge sample not repeated)
/// - `Mirror`:      `b a | a b c d | d c` (edge sample repeated)
/// - `Wrap`:        `c d | a b c d | a b`
///
/// `axis` may be negative, in which case it counts back from the last dimension.
///
/// # Errors
///
/// - [`Error::InvalidArgument`] if `axis` is outside `[-ndim, ndim)`.
/// - [`Error::InvalidArgument`] if the padded axis has length zero and `mode` is not
///   `Constant` (there is no data to extend).
/// - Any error returned by the underlying tensor operations.
pub fn pad_axis_impl<R, C>(
    client: &C,
    input: &Tensor<R>,
    axis: isize,
    pad_before: usize,
    pad_after: usize,
    mode: BoundaryMode,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + ShapeOps<R> + TensorOps<R> + RuntimeClient<R>,
{
    if pad_before == 0 && pad_after == 0 {
        return input.contiguous();
    }

    let ndim = input.ndim();
    // Resolve negative axes without relying on a wrapping `isize -> usize` cast.
    let resolved = if axis < 0 {
        ndim.checked_sub(axis.unsigned_abs())
    } else {
        usize::try_from(axis).ok()
    };
    let axis_normalized = match resolved {
        Some(a) if a < ndim => a,
        _ => {
            return Err(Error::InvalidArgument {
                arg: "axis",
                reason: format!(
                    "axis {} out of range for tensor with ndim {}",
                    axis,
                    input.ndim()
                ),
            });
        }
    };
    let axis_len = input.shape()[axis_normalized];

    // Every non-constant mode derives its padding from existing samples along the axis, so an
    // empty axis cannot be extended (previously this underflowed in `Nearest` or silently
    // returned a wrongly-shaped tensor).
    if axis_len == 0 && !matches!(mode, BoundaryMode::Constant(_)) {
        return Err(Error::InvalidArgument {
            arg: "input",
            reason: format!(
                "cannot pad empty axis {} with a non-constant boundary mode",
                axis
            ),
        });
    }

    match mode {
        BoundaryMode::Constant(value) => {
            // One (before, after) pair per dimension, ordered from the last dimension to the
            // first — this is the layout the index arithmetic below assumes `client.pad` takes.
            let mut padding = vec![0usize; ndim * 2];
            let dim_idx = ndim - axis_normalized - 1;
            padding[dim_idx * 2] = pad_before;
            padding[dim_idx * 2 + 1] = pad_after;
            client.pad(input, &padding, value)
        }
        BoundaryMode::Nearest => {
            concat_padded(client, input, axis, pad_before, pad_after, |side, count| {
                edge_block(client, input, axis, axis_normalized, count, side)
            })
        }
        BoundaryMode::Reflect => {
            concat_padded(client, input, axis, pad_before, pad_after, |side, count| {
                if axis_len == 1 {
                    // A lone sample reflects onto itself, so reflection degenerates to
                    // repeating it.
                    edge_block(client, input, axis, axis_normalized, count, side)
                } else {
                    reflect_block(client, input, axis, axis_normalized, count, side)
                }
            })
        }
        BoundaryMode::Mirror => {
            concat_padded(client, input, axis, pad_before, pad_after, |side, count| {
                mirror_block(client, input, axis, axis_normalized, count, side)
            })
        }
        BoundaryMode::Wrap => {
            concat_padded(client, input, axis, pad_before, pad_after, |side, count| {
                cyclic_block(client, input, axis, axis_normalized, count, side)
            })
        }
    }
}

/// Concatenates `[before, input, after]` along `axis`, obtaining each non-empty side's
/// padding from `block(side, width)`; zero-width sides are omitted.
fn concat_padded<R, C, F>(
    client: &C,
    input: &Tensor<R>,
    axis: isize,
    pad_before: usize,
    pad_after: usize,
    mut block: F,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + ShapeOps<R> + TensorOps<R> + RuntimeClient<R>,
    F: FnMut(PadSide, usize) -> Result<Tensor<R>>,
{
    let mut parts: Vec<Tensor<R>> = Vec::with_capacity(3);
    if pad_before > 0 {
        parts.push(block(PadSide::Before, pad_before)?);
    }
    parts.push(input.contiguous()?);
    if pad_after > 0 {
        parts.push(block(PadSide::After, pad_after)?);
    }
    let refs: Vec<&Tensor<R>> = parts.iter().collect();
    client.cat(&refs, axis)
}

/// `count` copies of the boundary sample on `side` (`a a | a b c d | d d`).
///
/// The axis must be non-empty.
fn edge_block<R, C>(
    client: &C,
    input: &Tensor<R>,
    axis: isize,
    axis_normalized: usize,
    count: usize,
    side: PadSide,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + ShapeOps<R> + TensorOps<R> + RuntimeClient<R>,
{
    let axis_len = input.shape()[axis_normalized];
    let edge_index = match side {
        PadSide::Before => 0,
        PadSide::After => axis_len - 1,
    };
    let edge = input.narrow(axis, edge_index, 1)?;
    let mut repeats = vec![1usize; input.ndim()];
    repeats[axis_normalized] = count;
    Ok(client.repeat(&edge, &repeats)?)
}

/// `count` samples of the reflection about the edge sample, which is itself not repeated
/// (`c b | a b c d | c b`).
///
/// The axis length must be at least 2. The reflected extension is periodic with period
/// `2 * (len - 1)`, so widths beyond `len - 1` are produced by cycling one full period.
fn reflect_block<R, C>(
    client: &C,
    input: &Tensor<R>,
    axis: isize,
    axis_normalized: usize,
    count: usize,
    side: PadSide,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + ShapeOps<R> + TensorOps<R> + RuntimeClient<R>,
{
    let axis_len = input.shape()[axis_normalized];
    let interior = axis_len - 1;
    if count <= interior {
        // Fast path: a single reflection supplies every requested sample.
        let start = match side {
            PadSide::Before => 1,
            PadSide::After => axis_len - 1 - count,
        };
        return Ok(input.narrow(axis, start, count)?.flip(axis)?);
    }
    let head = input.narrow(axis, 0, interior)?; // x[0 .. len-1]
    let tail = input.narrow(axis, 1, interior)?; // x[1 .. len]
    // One full period, oriented so that the sample adjacent to the data sits at the end
    // (Before) or the start (After): `a b c d c b` before, `c b a b c d` after.
    let period = match side {
        PadSide::Before => client.cat(&[&head, &tail.flip(axis)?], axis)?,
        PadSide::After => client.cat(&[&head.flip(axis)?, &tail], axis)?,
    };
    cyclic_block(client, &period, axis, axis_normalized, count, side)
}

/// `count` samples of the reflection that repeats the edge sample (`b a | a b c d | d c`).
///
/// The axis must be non-empty. The mirrored extension is periodic with period `2 * len`, so
/// widths beyond `len` are produced by cycling one full period.
fn mirror_block<R, C>(
    client: &C,
    input: &Tensor<R>,
    axis: isize,
    axis_normalized: usize,
    count: usize,
    side: PadSide,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + ShapeOps<R> + TensorOps<R> + RuntimeClient<R>,
{
    let axis_len = input.shape()[axis_normalized];
    if count <= axis_len {
        // Fast path: a single reflection supplies every requested sample.
        let start = match side {
            PadSide::Before => 0,
            PadSide::After => axis_len - count,
        };
        return Ok(input.narrow(axis, start, count)?.flip(axis)?);
    }
    let flipped = input.flip(axis)?;
    // One full period: `a b c d d c b a` before, `d c b a a b c d` after.
    let period = match side {
        PadSide::Before => client.cat(&[input, &flipped], axis)?,
        PadSide::After => client.cat(&[&flipped, input], axis)?,
    };
    cyclic_block(client, &period, axis, axis_normalized, count, side)
}

/// `count` samples of the infinite repetition of `unit` along `axis`.
///
/// For [`PadSide::Before`] these are the last `count` samples of `… unit unit` (ending with
/// the last sample of `unit`); for [`PadSide::After`] they are the first `count` samples of
/// `unit unit …`. `unit` must be non-empty along `axis`. Used directly for `Wrap`, and with a
/// one-period unit for `Reflect` and `Mirror`.
fn cyclic_block<R, C>(
    client: &C,
    unit: &Tensor<R>,
    axis: isize,
    axis_normalized: usize,
    count: usize,
    side: PadSide,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + ShapeOps<R> + TensorOps<R> + RuntimeClient<R>,
{
    let period = unit.shape()[axis_normalized];
    let copies = count.div_ceil(period);
    let tiled;
    let source = if copies <= 1 {
        unit
    } else {
        tiled = client.cat(&vec![unit; copies], axis)?;
        &tiled
    };
    // Length of `source` along the axis.
    let available = period * copies.max(1);
    let start = match side {
        PadSide::Before => available - count,
        PadSide::After => 0,
    };
    Ok(source.narrow(axis, start, count)?)
}