use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use scirs2_core::ndarray::Order;
/// Rolls array elements cyclically by `shift` positions.
///
/// An element at position `i` along the rolled dimension moves to position
/// `(i + shift) mod n`, so elements pushed past the end re-enter at the start.
/// Negative shifts roll toward the start, and shifts of any magnitude wrap.
///
/// # Arguments
///
/// * `array` - The input array.
/// * `shift` - Number of positions to roll; reduced modulo the rolled length.
/// * `axis` - Axis to roll along. `None` rolls the flattened data and then
///   restores the original shape.
///
/// # Errors
///
/// Returns `NumRs2Error::DimensionMismatch` if `axis` is not a valid axis.
///
/// When `is_parallel_enabled()` is true and the array has at least
/// `PARALLEL_THRESHOLD` elements, the output is gathered with a parallel iterator;
/// otherwise the data is rotated in place with `slice::rotate_right`.
pub fn roll<T: Clone + Send + Sync>(
    array: &Array<T>,
    shift: isize,
    axis: Option<usize>,
) -> Result<Array<T>> {
    use scirs2_core::parallel_ops::*;
    const PARALLEL_THRESHOLD: usize = 10000;

    if array.size() == 0 {
        return Ok(array.clone());
    }
    let shape = array.shape();
    match axis {
        Some(ax) => {
            if ax >= shape.len() {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    ax,
                    shape.len()
                )));
            }
            let axis_size = shape[ax];
            if axis_size <= 1 {
                return Ok(array.clone());
            }
            // `rem_euclid` maps any shift (including negative ones) into 0..axis_size.
            let shift_mod = shift.rem_euclid(axis_size as isize) as usize;
            if shift_mod == 0 {
                return Ok(array.clone());
            }
            // NOTE(review): assumes `to_vec` yields row-major (C-order) data; the index
            // arithmetic below depends on it. In that layout the data is a run of
            // contiguous blocks of `axis_size * post_axis_size` elements (one per
            // combination of leading-axis indices). Each block holds `axis_size`
            // consecutive slices of `post_axis_size` elements.
            let post_axis_size: usize = shape.iter().skip(ax + 1).product();
            let block_len = axis_size * post_axis_size;
            let mut data = array.to_vec();
            let total_size = data.len();
            let use_parallel = is_parallel_enabled() && total_size >= PARALLEL_THRESHOLD;
            let result_vec: Vec<T> = if use_parallel {
                let src = &data;
                (0..total_size)
                    .into_par_iter()
                    .map(move |dst_idx| {
                        let i_block = dst_idx / block_len;
                        let within = dst_idx % block_len;
                        let dst_axis_idx = within / post_axis_size;
                        let i_post = within % post_axis_size;
                        // Inverse of the roll: the destination slice comes from `shift_mod` earlier.
                        let src_axis_idx = (dst_axis_idx + axis_size - shift_mod) % axis_size;
                        src[i_block * block_len + src_axis_idx * post_axis_size + i_post].clone()
                    })
                    .collect()
            } else {
                // Rolling a block's slices by `shift_mod` is the same as rotating the
                // block's flat data by `shift_mod * post_axis_size` elements.
                for block in data.chunks_exact_mut(block_len) {
                    block.rotate_right(shift_mod * post_axis_size);
                }
                data
            };
            Ok(Array::from_vec(result_vec).reshape(&shape))
        }
        None => {
            let mut data = array.to_vec();
            let size = data.len();
            if size <= 1 {
                return Ok(array.clone());
            }
            let shift_mod = shift.rem_euclid(size as isize) as usize;
            if shift_mod == 0 {
                return Ok(array.clone());
            }
            let result_vec: Vec<T> = if is_parallel_enabled() && size >= PARALLEL_THRESHOLD {
                let src = &data;
                (0..size)
                    .into_par_iter()
                    .map(move |i| src[(i + size - shift_mod) % size].clone())
                    .collect()
            } else {
                // rotate_right(k) moves the element at i to (i + k) % size, i.e. a roll.
                data.rotate_right(shift_mod);
                data
            };
            Ok(Array::from_vec(result_vec).reshape(&shape))
        }
    }
}
/// Reverses the order of elements along one axis, or along every axis.
///
/// # Arguments
///
/// * `array` - The input array.
/// * `axis` - Axis to reverse. `None` reverses every axis.
///
/// # Errors
///
/// Returns `NumRs2Error::DimensionMismatch` if `axis` is not a valid axis.
///
/// Empty arrays, and axes of length 0 or 1, are returned unchanged as a clone.
pub fn flip<T: Clone>(array: &Array<T>, axis: Option<usize>) -> Result<Array<T>> {
    if array.size() == 0 {
        return Ok(array.clone());
    }
    let shape = array.shape();
    match axis {
        Some(ax) => {
            if ax >= shape.len() {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    ax,
                    shape.len()
                )));
            }
            let axis_size = shape[ax];
            if axis_size <= 1 {
                return Ok(array.clone());
            }
            // NOTE(review): assumes `to_vec` yields row-major (C-order) data. Each
            // contiguous block of `axis_size * post_axis_size` elements then holds the
            // `axis_size` slices along `ax` for one leading-axis index.
            let post_axis_size: usize = shape.iter().skip(ax + 1).product();
            let mut data = array.to_vec();
            for block in data.chunks_exact_mut(axis_size * post_axis_size) {
                // Reversing the whole block reverses the slice order AND each slice's
                // contents. Reversing every slice again restores its inner order.
                block.reverse();
                for slice in block.chunks_exact_mut(post_axis_size) {
                    slice.reverse();
                }
            }
            Ok(Array::from_vec(data).reshape(&shape))
        }
        None => {
            // A size-1 array (including 0-d) is unchanged by flipping every axis.
            if array.size() <= 1 {
                return Ok(array.clone());
            }
            // In row-major order, flipping every axis maps flat index f to (n - 1 - f),
            // so one reversal replaces `ndim` successive single-axis flips.
            let mut data = array.to_vec();
            data.reverse();
            Ok(Array::from_vec(data).reshape(&shape))
        }
    }
}
/// Reverses element order along axis 0 (the "up/down" direction).
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` for a 0-dimensional input. Otherwise
/// returns whatever `flip(array, Some(0))` returns.
pub fn flipud<T: Clone>(array: &Array<T>) -> Result<Array<T>> {
    match array.ndim() {
        0 => Err(NumRs2Error::InvalidOperation(
            "Input must be at least 1-dimensional".into(),
        )),
        _ => flip(array, Some(0)),
    }
}
/// Reverses element order along axis 1 (the "left/right" direction).
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` when the input has fewer than two
/// dimensions. Otherwise returns whatever `flip(array, Some(1))` returns.
pub fn fliplr<T: Clone>(array: &Array<T>) -> Result<Array<T>> {
    let has_second_axis = array.ndim() >= 2;
    if has_second_axis {
        flip(array, Some(1))
    } else {
        Err(NumRs2Error::InvalidOperation(
            "Input must be at least 2-dimensional".into(),
        ))
    }
}
/// Rotates an array by `k` quarter turns in the plane spanned by two axes.
///
/// # Arguments
///
/// * `array` - The input array.
/// * `k` - Number of quarter turns (defaults to 1). Any integer is reduced modulo 4,
///   so negative values turn the other way.
/// * `axes` - The pair of axes defining the rotation plane (defaults to `(0, 1)`).
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if either axis is out of bounds.
/// * `NumRs2Error::InvalidOperation` if both axes are the same.
///
/// One quarter turn swaps the two axes and then flips along the first. Three
/// quarter turns swap them and flip along the second. A half turn flips along
/// both axes.
pub fn rot90<T: Clone>(
    array: &Array<T>,
    k: impl Into<Option<i32>>,
    axes: impl Into<Option<(usize, usize)>>,
) -> Result<Array<T>> {
    let turns = k.into().unwrap_or(1);
    let (first, second) = axes.into().unwrap_or((0, 1));
    let ndim = array.ndim();

    if first >= ndim || second >= ndim {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axes ({}, {}) out of bounds for array of dimension {}",
            first, second, ndim
        )));
    }
    if first == second {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Axes ({}, {}) must be different",
            first, second
        )));
    }

    match turns.rem_euclid(4) {
        0 => Ok(array.clone()),
        2 => {
            let flipped_once = flip(array, Some(first))?;
            flip(&flipped_once, Some(second))
        }
        1 => {
            let swapped = array.clone().transpose_axis(first, second);
            flip(&swapped, Some(first))
        }
        _ => {
            // Only three quarter turns remain after the reduction modulo 4.
            let swapped = array.clone().transpose_axis(first, second);
            flip(&swapped, Some(second))
        }
    }
}
/// Inserts a new axis of length 1 at position `axis`.
///
/// `axis` may equal the current number of dimensions, which appends the new
/// axis at the end.
///
/// # Errors
///
/// Returns `NumRs2Error::DimensionMismatch` when `axis` is greater than the
/// number of dimensions.
pub fn expand_dims<T: Clone>(array: &Array<T>, axis: usize) -> Result<Array<T>> {
    let old_shape = array.shape();
    let ndim = old_shape.len();
    if axis <= ndim {
        let mut expanded = old_shape.clone();
        expanded.insert(axis, 1);
        return Ok(array.reshape(&expanded));
    }
    Err(NumRs2Error::DimensionMismatch(format!(
        "Axis {} out of bounds for array of dimension {}",
        axis, ndim
    )))
}
/// Removes axes of length 1.
///
/// # Arguments
///
/// * `array` - The input array.
/// * `axis` - A specific axis to remove, which must have length 1. `None`
///   removes every length-1 axis. If all axes have length 1, the result has
///   shape `[1]`.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `axis` is out of bounds.
/// * `NumRs2Error::InvalidOperation` if the requested axis does not have length 1.
pub fn squeeze<T: Clone>(array: &Array<T>, axis: Option<usize>) -> Result<Array<T>> {
    let shape = array.shape();

    if let Some(ax) = axis {
        let ndim = shape.len();
        if ax >= ndim {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Axis {} out of bounds for array of dimension {}",
                ax, ndim
            )));
        }
        let extent = shape[ax];
        if extent != 1 {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Cannot squeeze axis {} with size {}",
                ax, extent
            )));
        }
        let mut reduced = shape.clone();
        reduced.remove(ax);
        return Ok(array.reshape(&reduced));
    }

    let kept: Vec<_> = shape.iter().copied().filter(|&dim| dim != 1).collect();
    if kept.is_empty() {
        // Every axis had length 1: fall back to a single-element 1-D shape.
        return Ok(array.reshape(&[1]));
    }
    Ok(array.reshape(&kept))
}
/// Returns the elements of `array` as a new 1-D array.
///
/// # Arguments
///
/// * `array` - The input array.
/// * `order` - `'C'` (default) reads elements in row-major order. `'F'` reads
///   them in column-major order by iterating over `array.transpose()`.
///   NOTE(review): `'F'` presumes `transpose` reverses all axes (NumPy
///   semantics); confirm for arrays with more than 2 dimensions.
///
/// # Errors
///
/// Returns `NumRs2Error::InvalidOperation` if `order` is not `'C'` or `'F'`.
/// The check runs before any shortcut, so invalid orders are rejected even for
/// empty or 1-D inputs.
pub fn ravel<T: Clone>(array: &Array<T>, order: Option<char>) -> Result<Array<T>> {
    // Validate first: previously an empty or 1-D array silently accepted any order.
    let order_val = order.unwrap_or('C');
    let nd_order = match order_val {
        'C' => Order::RowMajor,
        'F' => Order::ColumnMajor,
        _ => {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Order must be 'C' or 'F', got '{}'",
                order_val
            )))
        }
    };
    if array.size() == 0 {
        return Ok(Array::from_vec(Vec::<T>::new()));
    }
    // A 1-D array reads the same in either order.
    if array.ndim() == 1 {
        return Ok(array.clone());
    }
    let flat_data = match nd_order {
        Order::RowMajor => array.array().iter().cloned().collect::<Vec<_>>(),
        Order::ColumnMajor => {
            let transposed = array.transpose();
            transposed.array().iter().cloned().collect::<Vec<_>>()
        }
        _ => {
            return Err(NumRs2Error::InvalidOperation(
                "Unsupported memory order".to_string(),
            ));
        }
    };
    Ok(Array::from_vec(flat_data))
}
/// Returns a flattened 1-D copy of `array`.
///
/// Delegates to `ravel`, so `order` accepts the same values (`'C'` by default,
/// or `'F'`) and produces the same errors.
pub fn flatten<T: Clone>(array: &Array<T>, order: Option<char>) -> Result<Array<T>> {
    let flattened = ravel(array, order)?;
    Ok(flattened)
}