use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Zero;
use scirs2_core::ndarray::IxDyn;
use std::cmp;
/// Extract a diagonal from a 2-D array, or build a 2-D array from a 1-D diagonal
/// (NumPy `diag` semantics).
///
/// * 1-D input of length `n`: returns a zero-filled `(n + |k|) x (n + |k|)` array whose
///   `k`-th diagonal holds the input values in order.
/// * 2-D input: returns a 1-D array holding its `k`-th diagonal. The result is empty when
///   that diagonal lies entirely outside the matrix.
///
/// `k` defaults to `0` (the main diagonal). `k > 0` selects diagonals above it and `k < 0`
/// selects diagonals below it.
///
/// # Errors
///
/// * [`NumRs2Error::InvalidOperation`] if the input is neither 1-D nor 2-D, or if the square
///   output for a 1-D input would have more elements than `usize` can count.
/// * [`NumRs2Error::DimensionMismatch`] if a 2-D array reports a shape of another rank.
pub fn diag<T: Clone + Zero>(array: &Array<T>, k: impl Into<Option<isize>>) -> Result<Array<T>> {
    let k = k.into().unwrap_or(0);
    // `unsigned_abs` is exact for every `isize`. `(-k) as usize` overflows (and panics in
    // debug builds) for `isize::MIN`.
    let shift = k.unsigned_abs();
    // (row, col) of the diagonal's first element. It starts right of the main diagonal
    // for k >= 0 and below it for k < 0.
    let (row_start, col_start) = if k >= 0 { (0, shift) } else { (shift, 0) };
    let ndim = array.ndim();
    match ndim {
        1 => {
            let values = array.to_vec();
            // Guard both the side length and the element count against overflow, so a
            // huge offset yields an error instead of a panic or an impossible allocation.
            let side = values
                .len()
                .checked_add(shift)
                .filter(|&side| side.checked_mul(side).is_some())
                .ok_or_else(|| {
                    NumRs2Error::InvalidOperation(format!(
                        "Diagonal offset {} is too large for an input of length {}",
                        k,
                        values.len()
                    ))
                })?;
            let mut out = vec![T::zero(); side * side];
            for (i, value) in values.into_iter().enumerate() {
                // Row-major flat index of element (row_start + i, col_start + i).
                out[(row_start + i) * side + col_start + i] = value;
            }
            Ok(Array::from_vec(out).reshape(&[side, side]))
        }
        2 => {
            let shape = array.shape();
            if shape.len() != 2 {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Expected a 2D array, got shape {:?}",
                    shape
                )));
            }
            let (rows, cols) = (shape[0], shape[1]);
            let diag_len = cmp::min(
                rows.saturating_sub(row_start),
                cols.saturating_sub(col_start),
            );
            if diag_len == 0 {
                return Ok(Array::zeros(&[0]));
            }
            // `to_vec` is treated as row-major here, as elsewhere in this module.
            let values = array.to_vec();
            let diagonal: Vec<T> = (0..diag_len)
                .map(|i| values[(row_start + i) * cols + col_start + i].clone())
                .collect();
            Ok(Array::from_vec(diagonal))
        }
        _ => Err(NumRs2Error::InvalidOperation(format!(
            "Input must be 1D or 2D array, got {}D array",
            ndim
        ))),
    }
}
/// Return the diagonal(s) taken along two axes of an N-D array (NumPy `diagonal` semantics).
///
/// The result's shape is the input shape with `axis1` and `axis2` removed (other axes keep
/// their order), followed by one trailing axis holding the diagonal elements.
///
/// * `offset` — diagonal offset (default `0`). Positive values select diagonals above the
///   main one, that is, increasing `axis2` index.
/// * `axis1` / `axis2` — the two axes spanning each 2-D sub-array (defaults `0` and `1`).
///
/// # Errors
///
/// * [`NumRs2Error::InvalidOperation`] if the array has fewer than two dimensions or
///   `axis1 == axis2`.
/// * [`NumRs2Error::DimensionMismatch`] if either axis is out of bounds.
pub fn diagonal<T: Clone + num_traits::Zero>(
    array: &Array<T>,
    offset: impl Into<Option<isize>>,
    axis1: impl Into<Option<usize>>,
    axis2: impl Into<Option<usize>>,
) -> Result<Array<T>> {
    let offset = offset.into().unwrap_or(0);
    let axis1 = axis1.into().unwrap_or(0);
    let axis2 = axis2.into().unwrap_or(1);
    let ndim = array.ndim();
    if ndim < 2 {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Array must be at least 2D, got {}D array",
            ndim
        )));
    }
    if axis1 == axis2 {
        return Err(NumRs2Error::InvalidOperation(format!(
            "axis1 and axis2 cannot be the same: {}",
            axis1
        )));
    }
    if axis1 >= ndim || axis2 >= ndim {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axes ({}, {}) out of bounds for array of dimension {}",
            axis1, axis2, ndim
        )));
    }

    let shape = array.shape();
    // `unsigned_abs` cannot overflow, unlike `(-offset) as usize` for `isize::MIN`.
    let shift = offset.unsigned_abs();
    // Index along (axis1, axis2) of the first diagonal element.
    let (row_start, col_start) = if offset >= 0 { (0, shift) } else { (shift, 0) };
    let diag_len = cmp::min(
        shape[axis1].saturating_sub(row_start),
        shape[axis2].saturating_sub(col_start),
    );

    // The remaining axes, in their original order, lead the result shape.
    let outer_axes: Vec<usize> = (0..ndim)
        .filter(|&ax| ax != axis1 && ax != axis2)
        .collect();
    let mut result_shape: Vec<usize> = outer_axes.iter().map(|&ax| shape[ax]).collect();
    result_shape.push(diag_len);
    if diag_len == 0 {
        return Ok(Array::zeros(&result_shape));
    }

    // Row-major strides. The flat data from `to_vec` is treated as C-ordered,
    // as elsewhere in this module.
    let mut strides = vec![1usize; ndim];
    for ax in (0..ndim - 1).rev() {
        strides[ax] = strides[ax + 1] * shape[ax + 1];
    }
    // Each diagonal step advances one index along both axes.
    let start = row_start * strides[axis1] + col_start * strides[axis2];
    let step = strides[axis1] + strides[axis2];

    let values = array.to_vec();
    let outer_count: usize = result_shape[..outer_axes.len()].iter().product();
    let mut result_vec = Vec::with_capacity(outer_count * diag_len);
    let mut outer_index = vec![0usize; outer_axes.len()];
    for _ in 0..outer_count {
        let base = start
            + outer_axes
                .iter()
                .zip(&outer_index)
                .map(|(&ax, &i)| i * strides[ax])
                .sum::<usize>();
        result_vec.extend((0..diag_len).map(|i| values[base + i * step].clone()));
        // Advance the outer multi-index in row-major order (last outer axis fastest).
        for pos in (0..outer_axes.len()).rev() {
            outer_index[pos] += 1;
            if outer_index[pos] < result_shape[pos] {
                break;
            }
            outer_index[pos] = 0;
        }
    }
    Ok(Array::from_vec(result_vec).reshape(&result_shape))
}
/// Roll `axis` backwards until it lies just before position `start` (NumPy `rollaxis`).
///
/// The other axes keep their relative order. `start == ndim` moves `axis` to the end.
/// When the permutation is not the identity, the result is a freshly allocated array with
/// the permuted shape.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `axis >= ndim`.
/// * [`NumRs2Error::InvalidOperation`] if `start > ndim`, or if a non-empty input's
///   storage cannot be viewed as one contiguous slice.
pub fn rollaxis<T: Clone + Zero>(array: &Array<T>, axis: usize, start: usize) -> Result<Array<T>> {
    let shape = array.shape();
    let ndim = shape.len();
    if axis >= ndim {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} out of bounds for array of dimension {}",
            axis, ndim
        )));
    }
    if start > ndim {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Start position {} exceeds array dimensions {}",
            start, ndim
        )));
    }
    // Inserting `axis` at its own position, or directly after itself, changes nothing.
    if start == axis || start == axis + 1 {
        return Ok(array.clone());
    }

    // axes[j] = source axis that becomes output axis j.
    let mut axes: Vec<usize> = (0..ndim).collect();
    let rolled = axes.remove(axis);
    axes.insert(if start <= axis { start } else { start - 1 }, rolled);
    let target_shape: Vec<usize> = axes.iter().map(|&ax| shape[ax]).collect();

    // For every *source* axis, the row-major stride of that axis in the output.
    let mut target_strides = vec![0usize; ndim];
    let mut stride = 1usize;
    for (pos, &ax) in axes.iter().enumerate().rev() {
        target_strides[ax] = stride;
        stride *= target_shape[pos];
    }

    let total = array.size();
    let mut result_data = vec![T::zero(); total];
    if total > 0 {
        // Fetch the contiguous view once rather than once per element.
        let storage = array.array();
        let source = storage.as_slice().ok_or_else(|| {
            NumRs2Error::InvalidOperation("Failed to get array slice".into())
        })?;
        let mut index = vec![0usize; ndim];
        let mut target_flat = 0usize;
        for value in source {
            result_data[target_flat] = value.clone();
            // Advance the row-major source index and keep the output offset in step,
            // so no per-element index vectors are needed.
            for ax in (0..ndim).rev() {
                index[ax] += 1;
                target_flat += target_strides[ax];
                if index[ax] < shape[ax] {
                    break;
                }
                target_flat -= index[ax] * target_strides[ax];
                index[ax] = 0;
            }
        }
    }
    Ok(Array::from_vec(result_data).reshape(&target_shape))
}
/// Swap two axes by copying every element into a new array (element-wise fallback).
///
/// Returns a clone when `axis1 == axis2`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if either axis is out of bounds.
/// * [`NumRs2Error::InvalidOperation`] if writing an element into the result fails.
#[allow(dead_code)]
fn array_transpose<T: Clone + Zero>(
    array: &Array<T>,
    axis1: usize,
    axis2: usize,
) -> Result<Array<T>> {
    let shape = array.shape();
    let ndim = shape.len();
    if axis1 >= ndim || axis2 >= ndim {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axes ({}, {}) out of bounds for array of dimension {}",
            axis1, axis2, ndim
        )));
    }
    if axis1 == axis2 {
        return Ok(array.clone());
    }

    let out_shape = {
        let mut swapped = shape.clone();
        swapped.swap(axis1, axis2);
        swapped
    };
    let mut result = Array::zeros(&out_shape);
    let source = array.array();
    for flat in 0..array.size() {
        // Decode `flat` into a row-major multi-index (last axis fastest).
        let mut src_index = vec![0usize; ndim];
        let mut rest = flat;
        for (slot, &extent) in src_index.iter_mut().zip(shape.iter()).rev() {
            *slot = rest % extent;
            rest /= extent;
        }
        let mut dst_index = src_index.clone();
        dst_index.swap(axis1, axis2);
        if let Some(value) = source.get(IxDyn(&src_index)) {
            result.set(&dst_index, value.clone()).map_err(|_| {
                NumRs2Error::InvalidOperation("Failed to set transposed value".into())
            })?;
        }
    }
    Ok(result)
}
/// Interchange two axes of an array (NumPy `swapaxes`).
///
/// Returns a clone when `axis1 == axis2`.
///
/// # Errors
///
/// [`NumRs2Error::DimensionMismatch`] if either axis is out of bounds.
pub fn swapaxes<T: Clone>(array: &Array<T>, axis1: usize, axis2: usize) -> Result<Array<T>> {
    let ndim = array.ndim();
    let in_bounds = axis1 < ndim && axis2 < ndim;
    if !in_bounds {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axes {} and {} are out of bounds for array of dimension {}",
            axis1, axis2, ndim
        )));
    }
    if axis1 == axis2 {
        return Ok(array.clone());
    }
    // A swap is a single transposition, so one `transpose_axis` call suffices.
    // The smaller axis is passed first.
    let (low, high) = if axis1 < axis2 {
        (axis1, axis2)
    } else {
        (axis2, axis1)
    };
    Ok(array.clone().transpose_axis(low, high))
}
/// Move axes of an array to new positions (NumPy `moveaxis`).
///
/// Axis `source[i]` ends up at position `destination[i]`. All other axes keep their
/// relative order.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `source` and `destination` differ in length,
///   or if any axis is out of bounds.
/// * [`NumRs2Error::InvalidOperation`] if an axis is repeated in `source` or in
///   `destination`.
pub fn moveaxis<T: Clone>(
    array: &Array<T>,
    source: &[usize],
    destination: &[usize],
) -> Result<Array<T>> {
    let ndim = array.ndim();
    if source.len() != destination.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Source and destination arrays must have the same length, got {} and {}",
            source.len(),
            destination.len()
        )));
    }
    for &axis in source.iter().chain(destination.iter()) {
        if axis >= ndim {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Axis {} is out of bounds for array of dimension {}",
                axis, ndim
            )));
        }
    }
    // Repeated axes make the requested order ambiguous, so reject them (as NumPy does).
    for &(axes, label) in [(source, "source"), (destination, "destination")].iter() {
        let mut seen = vec![false; ndim];
        for &axis in axes {
            if seen[axis] {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "Repeated axis {} in {} argument",
                    axis, label
                )));
            }
            seen[axis] = true;
        }
    }

    // perm[k] = original axis that must end up at position k.
    // Start with the unmoved axes in order, then insert the moved axes in increasing
    // destination order, so each insert lands at its final index.
    let mut is_moved = vec![false; ndim];
    for &src in source {
        is_moved[src] = true;
    }
    let mut perm: Vec<usize> = (0..ndim).filter(|&ax| !is_moved[ax]).collect();
    let mut moves: Vec<(usize, usize)> = destination
        .iter()
        .copied()
        .zip(source.iter().copied())
        .collect();
    moves.sort_unstable();
    for (dst, src) in moves {
        perm.insert(dst, src);
    }

    // current[k] = original axis currently at position k of `result`.
    // Each position is filled with one pairwise swap. This relies on
    // `transpose_axis(i, j)` exchanging axes i and j, as `swapaxes` does.
    let mut current: Vec<usize> = (0..ndim).collect();
    let mut result = array.clone();
    for pos in 0..ndim {
        if current[pos] == perm[pos] {
            continue;
        }
        if let Some(offset) = current[pos..].iter().position(|&ax| ax == perm[pos]) {
            let other = pos + offset;
            result = result.transpose_axis(pos, other);
            current.swap(pos, other);
        }
    }
    Ok(result)
}
/// Promote each input to at least one dimension (NumPy `atleast_1d`).
///
/// A 0-D array becomes a length-1 array. Arrays that already have one or more dimensions
/// are cloned unchanged.
///
/// # Errors
///
/// [`NumRs2Error::InvalidOperation`] if a 0-D array's scalar cannot be read. Processing
/// stops at the first such input.
pub fn atleast_1d<T: Clone + num_traits::Zero>(arys: &[&Array<T>]) -> Result<Vec<Array<T>>> {
    arys.iter()
        .map(|&array| {
            if array.ndim() != 0 {
                return Ok(array.clone());
            }
            match array.get(&[]) {
                Ok(scalar) => Ok(Array::from_vec(vec![scalar]).reshape(&[1])),
                Err(_) => Err(NumRs2Error::InvalidOperation(
                    "Failed to get scalar value".into(),
                )),
            }
        })
        .collect()
}
/// Promote each input to at least two dimensions (NumPy `atleast_2d`).
///
/// A 0-D array becomes `1 x 1`, and a 1-D array of length `n` becomes a `1 x n` row.
/// Higher-dimensional arrays are cloned unchanged.
///
/// # Errors
///
/// [`NumRs2Error::InvalidOperation`] if a 0-D array's scalar cannot be read. Processing
/// stops at the first such input.
pub fn atleast_2d<T: Clone + num_traits::Zero>(arys: &[&Array<T>]) -> Result<Vec<Array<T>>> {
    let mut promoted = Vec::with_capacity(arys.len());
    for array in arys.iter().copied() {
        let next = match array.ndim() {
            0 => match array.get(&[]) {
                Ok(scalar) => Array::from_vec(vec![scalar]).reshape(&[1, 1]),
                Err(_) => {
                    return Err(NumRs2Error::InvalidOperation(
                        "Failed to get scalar value".into(),
                    ))
                }
            },
            1 => {
                let flat = array.to_vec();
                let width = flat.len();
                Array::from_vec(flat).reshape(&[1, width])
            }
            _ => array.clone(),
        };
        promoted.push(next);
    }
    Ok(promoted)
}
/// Promote each input to at least three dimensions (NumPy `atleast_3d`).
///
/// Shapes are promoted as follows:
/// * 0-D becomes `(1, 1, 1)`.
/// * `(n,)` becomes `(1, n, 1)`.
/// * `(m, n)` becomes `(m, n, 1)`.
///
/// Arrays with three or more dimensions are cloned unchanged.
///
/// # Errors
///
/// [`NumRs2Error::InvalidOperation`] if a 0-D array's scalar cannot be read. Processing
/// stops at the first such input.
pub fn atleast_3d<T: Clone + num_traits::Zero>(arys: &[&Array<T>]) -> Result<Vec<Array<T>>> {
    arys.iter()
        .map(|&array| match array.ndim() {
            0 => array
                .get(&[])
                .map(|scalar| Array::from_vec(vec![scalar]).reshape(&[1, 1, 1]))
                .map_err(|_| NumRs2Error::InvalidOperation("Failed to get scalar value".into())),
            1 => {
                let flat = array.to_vec();
                let len = flat.len();
                Ok(Array::from_vec(flat).reshape(&[1, len, 1]))
            }
            2 => {
                let dims = array.shape();
                Ok(Array::from_vec(array.to_vec()).reshape(&[dims[0], dims[1], 1]))
            }
            _ => Ok(array.clone()),
        })
        .collect()
}
/// Broadcast every input against all the others (NumPy `broadcast_arrays`).
///
/// The common shape is found by folding `Array::broadcast_shape` over the inputs in
/// order. Each input is then broadcast to that shape. An empty input slice yields an
/// empty vector.
///
/// # Errors
///
/// Propagates the first error from `Array::broadcast_shape` (incompatible shapes) or
/// from `broadcast_to`.
pub fn broadcast_arrays<T: Clone>(arrays: &[&Array<T>]) -> Result<Vec<Array<T>>> {
    let (first, rest) = match arrays.split_first() {
        Some(parts) => parts,
        None => return Ok(vec![]),
    };
    let common_shape = rest.iter().try_fold(first.shape(), |acc, arr| {
        Array::<T>::broadcast_shape(&acc, &arr.shape())
    })?;
    arrays
        .iter()
        .map(|arr| arr.broadcast_to(&common_shape))
        .collect()
}
pub fn broadcast_to<T: Clone>(array: &Array<T>, shape: &[usize]) -> Result<Array<T>> {
array.broadcast_to(shape)
}