use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use scirs2_core::ndarray::IxDyn;
/// Build open-mesh index arrays from a sequence of input arrays.
///
/// For `n` inputs, the `i`-th output is the `i`-th input reshaped so that
/// every axis has length 1 except axis `i`, which holds all of that input's
/// elements. An empty input slice yields an empty vector.
///
/// # Errors
///
/// This function itself never produces an `Err`; the `Result` return type is
/// kept for API compatibility.
pub fn ix_<T: Clone>(arrays: &[&Array<T>]) -> Result<Vec<Array<T>>> {
    let ndim = arrays.len();
    let meshed = arrays
        .iter()
        .enumerate()
        .map(|(axis, arr)| {
            // All axes are singleton except the one this input occupies.
            let mut target = vec![1; ndim];
            target[axis] = arr.size();
            arr.reshape(&target)
        })
        .collect::<Vec<_>>();
    Ok(meshed)
}
/// Replace elements of `array` at the given flat (row-major) positions.
///
/// Each element of `indices` is converted through its `to_string()` form and
/// must parse as an `isize`. The value written at position `indices[i]` is
/// `values[i]`; `values` must therefore have at least as many elements as
/// `indices`.
///
/// `mode` controls out-of-range positions (default `"raise"`):
/// * `"raise"`: return an error for any position outside `0..array.size()`.
/// * `"wrap"`: wrap the position modulo `array.size()`, including negatives.
/// * `"clip"`: clamp the position into `0..=array.size() - 1`.
///
/// Writes happen in index order, so an error raised part-way through leaves
/// the earlier writes in place.
///
/// # Errors
///
/// * `InvalidOperation` in any of these cases:
///   * `indices` or `values` is not contiguous;
///   * an index is not an integer;
///   * there are fewer values than indices;
///   * `mode` is not recognised.
/// * `IndexOutOfBounds` for an out-of-range index in `"raise"` mode, or for
///   any index in `"wrap"`/`"clip"` mode when `array` is empty (there is no
///   element to wrap or clip onto).
/// * Any error returned by `Array::set`.
pub fn put<T: Clone + ToString>(
    array: &mut Array<T>,
    indices: &Array<T>,
    values: &Array<T>,
    mode: Option<&str>,
) -> Result<()> {
    let indices_slice = indices.array().as_slice().ok_or_else(|| {
        NumRs2Error::InvalidOperation("indices array should be contiguous".to_string())
    })?;
    // Parse every index once, up front, so a non-integer index is rejected
    // before anything is written.
    let parsed_indices = indices_slice
        .iter()
        .map(|v| {
            v.to_string().parse::<isize>().map_err(|_| {
                NumRs2Error::InvalidOperation("Indices must be integers".to_string())
            })
        })
        .collect::<Result<Vec<isize>>>()?;

    let n_indices = parsed_indices.len();
    let n_values = values.size();
    if n_values < n_indices {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Not enough values ({}) to put at all indices ({})",
            n_values, n_indices
        )));
    }

    let array_size = array.size();
    let handle_mode = mode.unwrap_or("raise");
    let values_slice = values.array().as_slice().ok_or_else(|| {
        NumRs2Error::InvalidOperation("values array should be contiguous".to_string())
    })?;
    // Writing elements never changes the shape, so compute it once.
    let shape = array.shape();
    let mut multi_idx = vec![0; shape.len()];

    // `values` has at least `n_indices` elements, so zipping pairs each index
    // with values[i] and stops after the last index.
    for (&idx_value, value) in parsed_indices.iter().zip(values_slice) {
        let idx = match handle_mode {
            "raise" => {
                if idx_value < 0 || idx_value >= array_size as isize {
                    return Err(NumRs2Error::IndexOutOfBounds(format!(
                        "Index {} is out of bounds for array with size {}",
                        idx_value, array_size
                    )));
                }
                idx_value as usize
            }
            // Wrapping/clipping into an empty array has no valid target and
            // would otherwise divide by zero or underflow.
            "wrap" | "clip" if array_size == 0 => {
                return Err(NumRs2Error::IndexOutOfBounds(format!(
                    "Cannot put index {} into an empty array",
                    idx_value
                )));
            }
            "wrap" => idx_value.rem_euclid(array_size as isize) as usize,
            "clip" => idx_value.clamp(0, array_size as isize - 1) as usize,
            _ => {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "Invalid mode: {}. Must be one of 'raise', 'wrap', or 'clip'",
                    handle_mode
                )));
            }
        };

        // Row-major unravel: the last axis varies fastest.
        let mut rem = idx;
        for (slot, &dim) in multi_idx.iter_mut().zip(shape.iter()).rev() {
            *slot = rem % dim;
            rem /= dim;
        }
        array.set(&multi_idx, value.clone())?;
    }
    Ok(())
}
/// Overwrite the elements of `array` where `mask` is true.
///
/// `mask` must have the same shape as `array`, and each of its elements must
/// render (via `to_string()`) as `"true"` or `"false"`. Masked positions are
/// visited in flat row-major order. The `k`-th masked position receives
/// `values[k % values.size()]`, so `values` is reused cyclically.
///
/// # Errors
///
/// * `ShapeMismatch` if `mask` and `array` have different shapes.
/// * `InvalidOperation` in any of these cases:
///   * `mask` or `values` is not contiguous;
///   * `mask` contains a non-boolean value;
///   * `values` is empty while at least one mask element is true.
/// * Any error returned by `Array::set`.
pub fn putmask<T: Clone + ToString, U: Clone + ToString>(
    array: &mut Array<T>,
    mask: &Array<U>,
    values: &Array<T>,
) -> Result<()> {
    let shape = array.shape();
    if shape != mask.shape() {
        return Err(NumRs2Error::ShapeMismatch {
            expected: shape,
            actual: mask.shape(),
        });
    }
    let mask_slice = mask.array().as_slice().ok_or_else(|| {
        NumRs2Error::InvalidOperation("mask array should be contiguous".to_string())
    })?;
    // Decode the mask once; every element must render as "true" or "false".
    let mask_bits = mask_slice
        .iter()
        .map(|m| match m.to_string().as_str() {
            "true" => Ok(true),
            "false" => Ok(false),
            _ => Err(NumRs2Error::InvalidOperation(
                "Mask must contain boolean values".to_string(),
            )),
        })
        .collect::<Result<Vec<bool>>>()?;

    let true_count = mask_bits.iter().filter(|&&b| b).count();
    let n_values = values.size();
    if n_values == 0 && true_count > 0 {
        return Err(NumRs2Error::InvalidOperation(
            "No values provided to fill masked elements".to_string(),
        ));
    }
    let values_slice = values.array().as_slice().ok_or_else(|| {
        NumRs2Error::InvalidOperation("values array should be contiguous".to_string())
    })?;

    let mut multi_idx = vec![0; shape.len()];
    let mut value_idx = 0;
    for (flat, _) in mask_bits.iter().enumerate().filter(|(_, &b)| b) {
        // Row-major unravel: the last axis varies fastest.
        let mut rem = flat;
        for (slot, &dim) in multi_idx.iter_mut().zip(shape.iter()).rev() {
            *slot = rem % dim;
            rem /= dim;
        }
        // n_values > 0 here: a true mask element with no values was rejected above.
        let value = values_slice[value_idx % n_values].clone();
        array.set(&multi_idx, value)?;
        value_idx += 1;
    }
    Ok(())
}
pub fn take<T: Clone + ToString + num_traits::Zero>(
array: &Array<T>,
indices: &Array<usize>,
axis: Option<usize>,
mode: Option<&str>,
) -> Result<Array<T>> {
let indices_slice = indices.array().as_slice().ok_or_else(|| {
NumRs2Error::InvalidOperation("indices array should be contiguous".to_string())
})?;
let handle_mode = mode.unwrap_or("raise");
let indices_vec: Vec<isize> = indices_slice.iter().map(|&x| x as isize).collect();
match axis {
None => {
let flat_data = array.to_vec();
let array_size = flat_data.len();
let mut result_data = Vec::with_capacity(indices_vec.len());
for &idx_value in &indices_vec {
let idx = match handle_mode {
"raise" => {
if idx_value < 0 || idx_value >= array_size as isize {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Index {} is out of bounds for array with size {}",
idx_value, array_size
)));
}
idx_value as usize
}
"wrap" => {
(((idx_value % array_size as isize) + array_size as isize)
% array_size as isize) as usize
}
"clip" => {
if idx_value < 0 {
0
} else if idx_value >= array_size as isize {
array_size - 1
} else {
idx_value as usize
}
}
_ => {
return Err(NumRs2Error::InvalidOperation(format!(
"Invalid mode: {}. Must be one of 'raise', 'wrap', or 'clip'",
handle_mode
)));
}
};
result_data.push(flat_data[idx].clone());
}
Ok(Array::from_vec(result_data))
}
Some(ax) => {
if ax >= array.ndim() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Axis {} is out of bounds for array with {} dimensions",
ax,
array.ndim()
)));
}
let shape = array.shape();
let axis_size = shape[ax];
let mut out_shape = shape.clone();
out_shape[ax] = indices_vec.len();
let processed_indices: Result<Vec<usize>> = indices_vec
.iter()
.map(|&idx_value| match handle_mode {
"raise" => {
if idx_value < 0 || idx_value >= axis_size as isize {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Index {} is out of bounds for axis with size {}",
idx_value, axis_size
)));
}
Ok(idx_value as usize)
}
"wrap" => Ok((((idx_value % axis_size as isize) + axis_size as isize)
% axis_size as isize) as usize),
"clip" => {
if idx_value < 0 {
Ok(0)
} else if idx_value >= axis_size as isize {
Ok(axis_size - 1)
} else {
Ok(idx_value as usize)
}
}
_ => Err(NumRs2Error::InvalidOperation(format!(
"Invalid mode: {}. Must be one of 'raise', 'wrap', or 'clip'",
handle_mode
))),
})
.collect();
let processed_indices = processed_indices?;
let mut result_data = Vec::new();
let total_elements = out_shape.iter().product::<usize>();
for result_idx in 0..total_elements {
let mut coords = vec![0; array.ndim()];
let mut remaining = result_idx;
for i in (0..array.ndim()).rev() {
let size = out_shape[i];
coords[i] = remaining % size;
remaining /= size;
}
let original_axis_coord = processed_indices[coords[ax]];
let mut orig_coords = coords.clone();
orig_coords[ax] = original_axis_coord;
let mut orig_linear_idx = 0;
let mut stride = 1;
for i in (0..array.ndim()).rev() {
orig_linear_idx += orig_coords[i] * stride;
stride *= shape[i];
}
result_data.push(array.to_vec()[orig_linear_idx].clone());
}
Ok(Array::from_vec(result_data).reshape(&out_shape))
}
}
}
/// Gather values from `array` along `axis` using per-position indices.
///
/// `indices` must have the same number of dimensions as `array` and match its
/// shape on every axis except `axis`. For each position `p` of `indices`, the
/// output at `p` is the element of `array` at `p`, with coordinate `axis`
/// replaced by `indices[p]`. The result has the shape of `indices`.
///
/// # Errors
///
/// * `DimensionMismatch` if `axis >= array.ndim()`.
/// * `InvalidOperation` if `indices` is not contiguous.
/// * `ShapeMismatch` if the ranks differ or a dimension other than `axis` differs.
/// * `IndexOutOfBounds` if an index is `>= array.shape()[axis]`.
pub fn take_along_axis<T: Clone + ToString>(
    array: &Array<T>,
    indices: &Array<usize>,
    axis: usize,
) -> Result<Array<T>> {
    if axis >= array.ndim() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} is out of bounds for array with {} dimensions",
            axis,
            array.ndim()
        )));
    }
    let indices_slice = indices.array().as_slice().ok_or_else(|| {
        NumRs2Error::InvalidOperation("indices array should be contiguous".to_string())
    })?;
    let array_shape = array.shape();
    let indices_shape = indices.shape();
    let axis_size = array_shape[axis];

    // Ranks must agree. Otherwise `zip` below would silently compare only a
    // prefix, and `multi_idx[axis]` could index past the end and panic.
    let rank_mismatch = array_shape.len() != indices_shape.len();
    let dim_mismatch = array_shape
        .iter()
        .zip(indices_shape.iter())
        .enumerate()
        .any(|(i, (a, b))| i != axis && a != b);
    if rank_mismatch || dim_mismatch {
        return Err(NumRs2Error::ShapeMismatch {
            expected: array_shape,
            actual: indices_shape,
        });
    }

    let mut result_data = Vec::with_capacity(indices_slice.len());
    let mut multi_idx = vec![0; indices_shape.len()];
    for (flat_idx, &idx_value) in indices_slice.iter().enumerate() {
        if idx_value >= axis_size {
            return Err(NumRs2Error::IndexOutOfBounds(format!(
                "Index {} is out of bounds for axis with size {}",
                idx_value, axis_size
            )));
        }
        // Row-major unravel of the position within `indices`; every slot is
        // overwritten, so the buffer can be reused across iterations.
        let mut rem = flat_idx;
        for (slot, &dim) in multi_idx.iter_mut().zip(indices_shape.iter()).rev() {
            *slot = rem % dim;
            rem /= dim;
        }
        multi_idx[axis] = idx_value;
        let value = array
            .array()
            .get(IxDyn(&multi_idx))
            .ok_or_else(|| {
                NumRs2Error::IndexOutOfBounds(
                    "multi_idx should be valid as index was validated".to_string(),
                )
            })?
            .clone();
        result_data.push(value);
    }
    Ok(Array::from_vec(result_data).reshape(&indices_shape))
}
/// Scatter `values` into `array` along `axis` using per-position indices.
///
/// `indices` and `values` must have identical shapes. `indices` must have the
/// same number of dimensions as `array` and match its shape on every axis
/// except `axis`. For each position `p` of `indices`, `values[p]` is written to
/// the element of `array` at `p`, with coordinate `axis` replaced by
/// `indices[p]`. Writes happen in flat order, so an error part-way through
/// leaves the earlier writes in place.
///
/// # Errors
///
/// * `DimensionMismatch` if `axis >= array.ndim()`.
/// * `InvalidOperation` if `indices` is not contiguous.
/// * `ShapeMismatch` in any of these cases:
///   * `indices` and `values` differ in shape;
///   * `indices` and `array` differ in rank;
///   * a dimension other than `axis` differs.
/// * `IndexOutOfBounds` if an index is `>= array.shape()[axis]`.
/// * Any error returned by `Array::set`.
pub fn put_along_axis<T: Clone + ToString>(
    array: &mut Array<T>,
    indices: &Array<usize>,
    values: &Array<T>,
    axis: usize,
) -> Result<()> {
    if axis >= array.ndim() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Axis {} is out of bounds for array with {} dimensions",
            axis,
            array.ndim()
        )));
    }
    let indices_slice = indices.array().as_slice().ok_or_else(|| {
        NumRs2Error::InvalidOperation("indices array should be contiguous".to_string())
    })?;
    let array_shape = array.shape();
    let indices_shape = indices.shape();
    let values_shape = values.shape();
    let axis_size = array_shape[axis];
    if indices_shape != values_shape {
        return Err(NumRs2Error::ShapeMismatch {
            expected: indices_shape,
            actual: values_shape,
        });
    }

    // Ranks must agree. Otherwise `zip` below would silently compare only a
    // prefix, and `multi_idx[axis]` could index past the end and panic.
    let rank_mismatch = array_shape.len() != indices_shape.len();
    let dim_mismatch = array_shape
        .iter()
        .zip(indices_shape.iter())
        .enumerate()
        .any(|(i, (a, b))| i != axis && a != b);
    if rank_mismatch || dim_mismatch {
        return Err(NumRs2Error::ShapeMismatch {
            expected: array_shape,
            actual: indices_shape,
        });
    }

    let values_data = values.to_vec();
    let mut multi_idx = vec![0; indices_shape.len()];
    // `values` has the same shape as `indices`, so the two zip one-to-one.
    for (flat_idx, (&idx_value, value)) in indices_slice.iter().zip(values_data).enumerate() {
        if idx_value >= axis_size {
            return Err(NumRs2Error::IndexOutOfBounds(format!(
                "Index {} is out of bounds for axis with size {}",
                idx_value, axis_size
            )));
        }
        // Row-major unravel of the position within `indices`; every slot is
        // overwritten, so the buffer can be reused across iterations.
        let mut rem = flat_idx;
        for (slot, &dim) in multi_idx.iter_mut().zip(indices_shape.iter()).rev() {
            *slot = rem % dim;
            rem /= dim;
        }
        multi_idx[axis] = idx_value;
        array.set(&multi_idx, value)?;
    }
    Ok(())
}
/// Return a 1-D array of the elements of `array` whose matching `condition`
/// element is true.
///
/// `condition` must have the same shape as `array`, and each of its elements
/// must render (via `to_string()`) as `"true"` or `"false"`. Selected elements
/// keep the order produced by `to_vec`.
///
/// # Errors
///
/// * `ShapeMismatch` if the shapes differ.
/// * `InvalidOperation` if `condition` is not contiguous or contains a
///   non-boolean value.
pub fn extract<T: Clone + ToString, U: Clone + ToString>(
    array: &Array<T>,
    condition: &Array<U>,
) -> Result<Array<T>> {
    let (arr_shape, cond_shape) = (array.shape(), condition.shape());
    if arr_shape != cond_shape {
        return Err(NumRs2Error::ShapeMismatch {
            expected: arr_shape,
            actual: cond_shape,
        });
    }

    let cond_slice = condition.array().as_slice().ok_or_else(|| {
        NumRs2Error::InvalidOperation("condition array should be contiguous".to_string())
    })?;
    let all_boolean = cond_slice
        .iter()
        .all(|c| matches!(c.to_string().as_str(), "true" | "false"));
    if !all_boolean {
        return Err(NumRs2Error::InvalidOperation(
            "Condition must contain boolean values".to_string(),
        ));
    }

    // Shapes are equal, so both flattened vectors have the same length.
    let selected: Vec<T> = array
        .to_vec()
        .into_iter()
        .zip(condition.to_vec())
        .filter(|(_, c)| c.to_string() == "true")
        .map(|(v, _)| v)
        .collect();
    Ok(Array::from_vec(selected))
}