use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Zero;
/// Pads an array along every axis, following the semantics of NumPy's `numpy.pad`.
///
/// # Arguments
/// * `array` - the array to pad.
/// * `pad_width` - one `(before, after)` pair per axis: how many elements to add at the
///   start and at the end of that axis.
/// * `mode` - the padding strategy:
///   * `"constant"`: fill with `constant_values` (or `T::zero()` when it is `None`);
///   * `"edge"`: repeat the first/last element of the axis;
///   * `"reflect"`: mirror about the edge element, excluding it
///     (`[1, 2, 3]` padded by 2 -> `[3, 2, 1, 2, 3, 2, 1]`);
///   * `"symmetric"`: mirror including the edge element
///     (`[1, 2, 3]` padded by 2 -> `[2, 1, 1, 2, 3, 3, 2]`);
///   * `"wrap"`: tile the axis periodically
///     (`[1, 2, 3]` padded by 2 -> `[2, 3, 1, 2, 3, 1, 2]`).
/// * `constant_values` - fill value for `"constant"` mode; ignored by every other mode.
///
/// # Errors
/// Returns `NumRs2Error::InvalidOperation` when:
/// * `pad_width.len()` differs from the number of dimensions;
/// * `mode` is not one of the names above;
/// * a mode other than `"constant"` is asked to pad an axis of length zero, because there is
///   no data to replicate.
pub fn pad<T>(
    array: &Array<T>,
    pad_width: &[(usize, usize)],
    mode: &str,
    constant_values: Option<T>,
) -> Result<Array<T>>
where
    T: Clone + Zero,
{
    let shape = array.shape();
    if pad_width.len() != shape.len() {
        return Err(NumRs2Error::InvalidOperation(format!(
            "pad_width must have same length as array dimensions. Got {} for {} dimensions",
            pad_width.len(),
            shape.len()
        )));
    }
    // Reject unknown modes before any buffer is allocated or filled.
    let pad_mode = PadMode::parse(mode)?;

    // Every mode except "constant" copies existing data outward, so an empty axis has
    // nothing to copy from.
    if pad_mode != PadMode::Constant {
        for (axis, (&dim, &(before, after))) in shape.iter().zip(pad_width).enumerate() {
            if dim == 0 && (before > 0 || after > 0) {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "Cannot extend empty axis {} using pad mode '{}'; only 'constant' can pad an empty axis",
                    axis, mode
                )));
            }
        }
    }

    let new_shape: Vec<usize> = shape
        .iter()
        .zip(pad_width)
        .map(|(&dim, &(before, after))| before + dim + after)
        .collect();
    let total_size: usize = new_shape.iter().product();

    // Non-constant modes overwrite every padded cell below, so this initial fill value only
    // survives in "constant" mode.
    let pad_value = match pad_mode {
        PadMode::Constant => constant_values.unwrap_or_else(T::zero),
        _ => T::zero(),
    };
    let mut result_data = vec![pad_value; total_size];

    let old_strides = row_major_strides(&shape);
    let new_strides = row_major_strides(&new_shape);

    // Scatter the original elements into the interior of the padded buffer: each coordinate
    // is shifted by that axis' `before` padding.
    let original_data = array.to_vec();
    for (flat, value) in original_data.iter().enumerate() {
        let old_indices = index_from_flat(flat, &shape, &old_strides);
        let new_flat: usize = old_indices
            .iter()
            .zip(pad_width)
            .zip(&new_strides)
            .map(|((&idx, &(before, _)), &stride)| (idx + before) * stride)
            .sum();
        result_data[new_flat] = value.clone();
    }

    if pad_mode == PadMode::Constant {
        return Ok(Array::from_vec(result_data).reshape(&new_shape));
    }

    // Fill the padding one axis at a time. Each padded cell copies from the cell that differs
    // only in the current axis' coordinate and lies inside the original extent of that axis.
    // Those source cells are never written during the same axis pass, so iteration order does
    // not matter. Axes processed earlier have already filled their own padding, which is how
    // corner regions end up correct.
    for (axis, &(before, after)) in pad_width.iter().enumerate() {
        if before == 0 && after == 0 {
            continue;
        }
        // Non-zero here: empty padded axes were rejected above.
        let len = shape[axis];
        for flat in 0..total_size {
            let mut indices = index_from_flat(flat, &new_shape, &new_strides);
            let pos = indices[axis];
            if pos >= before && pos < before + len {
                // Interior along this axis; nothing to fill.
                continue;
            }
            indices[axis] = before + source_offset(pad_mode, pos, before, len);
            let source_flat = flat_from_index(&indices, &new_strides);
            result_data[flat] = result_data[source_flat].clone();
        }
    }

    Ok(Array::from_vec(result_data).reshape(&new_shape))
}

/// Padding strategies accepted by [`pad`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PadMode {
    Constant,
    Edge,
    Reflect,
    Symmetric,
    Wrap,
}

impl PadMode {
    /// Parses the textual `mode` argument of [`pad`], rejecting unknown names.
    fn parse(mode: &str) -> Result<Self> {
        match mode {
            "constant" => Ok(Self::Constant),
            "edge" => Ok(Self::Edge),
            "reflect" => Ok(Self::Reflect),
            "symmetric" => Ok(Self::Symmetric),
            "wrap" => Ok(Self::Wrap),
            _ => Err(NumRs2Error::InvalidOperation(format!(
                "Unknown pad mode: {}. Must be one of: constant, edge, reflect, symmetric, wrap",
                mode
            ))),
        }
    }
}

/// Row-major (C-order) element strides for `shape`: the last axis has stride 1.
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    // Ranging over `1..len` (not `0..len - 1`) keeps 0-d shapes from underflowing.
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    strides
}

/// Computes `(pos - before) mod period`, reduced into `0..period` without signed arithmetic.
///
/// `period` must be non-zero.
fn wrap_offset(pos: usize, before: usize, period: usize) -> usize {
    if pos >= before {
        (pos - before) % period
    } else {
        (period - (before - pos) % period) % period
    }
}

/// Maps a padded coordinate `pos` to the offset (in `0..len`) of the original element it copies.
///
/// The original data occupies `before..before + len` on this axis. Requires `len >= 1` and
/// `pos` outside `before..before + len`. Never called for `PadMode::Constant`.
fn source_offset(mode: PadMode, pos: usize, before: usize, len: usize) -> usize {
    match mode {
        PadMode::Constant => unreachable!("constant padding never copies interior elements"),
        // Clamp to the nearest edge element.
        PadMode::Edge => {
            if pos < before {
                0
            } else {
                len - 1
            }
        }
        // Tile with period `len`.
        PadMode::Wrap => wrap_offset(pos, before, len),
        // Mirror without repeating the edge: period 2 * (len - 1). A length-1 axis has no
        // distinct neighbour, so it replicates its single element (as NumPy does).
        PadMode::Reflect => {
            if len == 1 {
                return 0;
            }
            let period = 2 * (len - 1);
            let r = wrap_offset(pos, before, period);
            if r < len {
                r
            } else {
                period - r
            }
        }
        // Mirror repeating the edge: period 2 * len.
        PadMode::Symmetric => {
            let period = 2 * len;
            let r = wrap_offset(pos, before, period);
            if r < len {
                r
            } else {
                period - 1 - r
            }
        }
    }
}
/// Decomposes a flat offset into one coordinate per axis of `shape`.
///
/// For each axis in turn, the coordinate is the running remainder divided by that axis'
/// stride, and the remainder is reduced modulo the stride. Only `shape.len()` is consulted;
/// `strides` must have at least that many entries, all non-zero.
fn index_from_flat(flat_idx: usize, shape: &[usize], strides: &[usize]) -> Vec<usize> {
    let mut remaining = flat_idx;
    (0..shape.len())
        .map(|axis| {
            let stride = strides[axis];
            let coord = remaining / stride;
            remaining %= stride;
            coord
        })
        .collect()
}
/// Collapses a multi-dimensional index into a flat offset.
///
/// Each coordinate is multiplied by the stride at the same position and the products are
/// summed. `strides` must be at least as long as `indices`.
fn flat_from_index(indices: &[usize], strides: &[usize]) -> usize {
    indices
        .iter()
        .enumerate()
        .map(|(axis, &coord)| coord * strides[axis])
        .sum()
}