use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Zero;
/// Roll the specified axis backwards until it lies in position `start`.
///
/// Mirrors NumPy's `rollaxis`: when `start <= axis` the rolled axis ends up at
/// index `start`; when `start > axis` it ends up at index `start - 1`. The
/// relative order of all other axes is preserved. The result is a freshly
/// allocated array whose elements are copied into the rolled layout.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `axis >= array.ndim()`.
/// * [`NumRs2Error::InvalidOperation`] if `start > array.ndim()`.
pub fn rollaxis<T: Clone + Zero>(array: &Array<T>, axis: usize, start: usize) -> Result<Array<T>> {
    let source_shape = array.shape().to_vec();
    let ndim = source_shape.len();
    if axis >= ndim {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} out of bounds for array of dimension {}",
            axis, ndim
        )));
    }
    if start > ndim {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Start position {} exceeds array dimensions {}",
            start, ndim
        )));
    }
    // Both cases leave the axis order unchanged.
    if axis == start || (axis == ndim - 1 && start == ndim) {
        return Ok(array.clone());
    }

    // `axes[j]` is the source axis that becomes axis `j` of the result.
    let mut axes: Vec<usize> = (0..ndim).collect();
    let rolled_axis = axes.remove(axis);
    axes.insert(if start <= axis { start } else { start - 1 }, rolled_axis);

    let target_shape: Vec<usize> = axes.iter().map(|&ax| source_shape[ax]).collect();

    // Row-major strides of the result buffer.
    let mut target_strides = vec![1usize; ndim];
    for j in (0..ndim - 1).rev() {
        target_strides[j] = target_strides[j + 1] * target_shape[j + 1];
    }
    // Offset in the result buffer covered by one step along each *source* axis.
    // Precomputing this avoids building index vectors for every element.
    let mut source_axis_stride = vec![0usize; ndim];
    for (j, &ax) in axes.iter().enumerate() {
        source_axis_stride[ax] = target_strides[j];
    }

    let mut result_data = vec![T::zero(); array.size()];
    // The element's position in this iteration is its row-major flat index in
    // the source shape. Iterating the underlying storage (assumed to be an
    // `ndarray` array) in logical order works for non-contiguous layouts too,
    // where `as_slice()` would return `None`. Zero-length dimensions yield no
    // elements, so the divisions below never divide by zero.
    for (flat, value) in array.array().iter().enumerate() {
        let mut remainder = flat;
        let mut target_flat_index = 0;
        for k in (0..ndim).rev() {
            target_flat_index += (remainder % source_shape[k]) * source_axis_stride[k];
            remainder /= source_shape[k];
        }
        result_data[target_flat_index] = value.clone();
    }

    Ok(Array::from_vec(result_data).reshape(&target_shape))
}
/// Interchange two axes of an array.
///
/// When `axis1 == axis2` the result is simply a copy of `array`.
///
/// # Errors
///
/// Returns [`NumRs2Error::DimensionMismatch`] if either axis is
/// `>= array.ndim()`.
pub fn swapaxes<T: Clone>(array: &Array<T>, axis1: usize, axis2: usize) -> Result<Array<T>> {
    let rank = array.ndim();
    let either_out_of_range = !(axis1 < rank && axis2 < rank);
    if either_out_of_range {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axes {} and {} are out of bounds for array of dimension {}",
            axis1, axis2, rank
        )));
    }
    if axis1 == axis2 {
        return Ok(array.clone());
    }

    // Swapping two axes is a single transposition, so one `transpose_axis`
    // call is enough; the lower axis index is passed first.
    let (lower, upper) = if axis1 < axis2 {
        (axis1, axis2)
    } else {
        (axis2, axis1)
    };
    let swapped = array.clone().transpose_axis(lower, upper);
    Ok(swapped)
}
/// Move axes of an array to new positions.
///
/// Mirrors NumPy's `moveaxis`: axis `source[k]` of `array` becomes axis
/// `destination[k]` of the result. Axes not named in `source` keep their
/// relative order in the remaining positions. Empty `source`/`destination`
/// slices return a copy of `array`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `source` and `destination` differ in
///   length, or if any listed axis is `>= array.ndim()`.
/// * [`NumRs2Error::InvalidOperation`] if an axis is repeated within `source`
///   or within `destination`.
pub fn moveaxis<T: Clone>(
    array: &Array<T>,
    source: &[usize],
    destination: &[usize],
) -> Result<Array<T>> {
    let ndim = array.ndim();
    if source.len() != destination.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Source and destination arrays must have the same length, got {} and {}",
            source.len(),
            destination.len()
        )));
    }
    for &axis in source.iter().chain(destination.iter()) {
        if axis >= ndim {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Axis {} is out of bounds for array of dimension {}",
                axis, ndim
            )));
        }
    }
    // Repeated axes would make the target order below not a permutation
    // (NumPy rejects them as well).
    for &(label, axes) in &[("source", source), ("destination", destination)] {
        let mut seen = vec![false; ndim];
        for &axis in axes {
            if seen[axis] {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "Repeated axis {} in {} argument",
                    axis, label
                )));
            }
            seen[axis] = true;
        }
    }

    // `order[k]` is the input axis that becomes axis `k` of the result. Start
    // from the axes that are not moved, then insert the moved axes in
    // increasing destination order. Because destinations are distinct and
    // < ndim, each insertion index is always <= the current length.
    let mut order: Vec<usize> = (0..ndim).filter(|axis| !source.contains(axis)).collect();
    let mut moves: Vec<(usize, usize)> = destination
        .iter()
        .copied()
        .zip(source.iter().copied())
        .collect();
    moves.sort_unstable();
    for (dst, src) in moves {
        order.insert(dst, src);
    }

    // Realise `order` through pairwise axis swaps. `current[k]` tracks which
    // input axis sits at position `k` of `result`. This assumes
    // `transpose_axis(i, j)` interchanges axes `i` and `j`; TODO confirm.
    let mut result = array.clone();
    let mut current: Vec<usize> = (0..ndim).collect();
    for i in 0..ndim {
        if current[i] != order[i] {
            let j = (i + 1..ndim)
                .find(|&j| current[j] == order[i])
                .expect("order is a permutation and positions < i are settled, so the axis lies after i");
            result = result.transpose_axis(i, j);
            current.swap(i, j);
        }
    }
    Ok(result)
}