use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Zero;
/// Returns a copy of `array` with the entries selected by `indices` removed.
///
/// * `axis = Some(ax)` removes whole sub-arrays at the given positions along
///   axis `ax`; the result has the input shape with that axis shortened.
/// * `axis = None` removes elements at the given positions of the flattened
///   array (as produced by `array.to_vec()`); the result is built from the
///   remaining elements with `Array::from_vec`.
///
/// Duplicate indices are allowed and count once.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `ax >= array.ndim()`.
/// * `NumRs2Error::InvalidOperation` if any index is out of bounds for the
///   selected axis (or for the flattened array).
/// * Any error returned by `array.get` while reading elements.
pub fn delete<T: Clone + Zero>(
    array: &Array<T>,
    indices: &[usize],
    axis: Option<usize>,
) -> Result<Array<T>> {
    // Writes the row-major coordinates of `flat` within `dims` into `out`
    // (`dims` and `out` have the same length; every `dim` must be non-zero).
    fn unravel_into(mut flat: usize, dims: &[usize], out: &mut [usize]) {
        for (coord, &dim) in out.iter_mut().zip(dims).rev() {
            *coord = flat % dim;
            flat /= dim;
        }
    }

    match axis {
        Some(ax) => {
            if ax >= array.ndim() {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    ax,
                    array.ndim()
                )));
            }
            let shape = array.shape();
            let axis_size = shape[ax];

            // A boolean mask gives O(1) membership tests and handles duplicate
            // indices for free (no sort/dedup, no per-element linear scan).
            let mut keep = vec![true; axis_size];
            for &idx in indices {
                if idx >= axis_size {
                    return Err(NumRs2Error::InvalidOperation(format!(
                        "Index {} out of bounds for axis {} with size {}",
                        idx, ax, axis_size
                    )));
                }
                keep[idx] = false;
            }
            let kept_positions: Vec<usize> = (0..axis_size).filter(|&pos| keep[pos]).collect();

            let mut new_shape = shape.clone();
            new_shape[ax] = kept_positions.len();
            if kept_positions.is_empty() {
                return Ok(Array::zeros(&new_shape));
            }

            // Walk the output in row-major order: everything before the axis,
            // then the surviving axis positions, then everything after it.
            // If any other dimension is zero, `outer` or `inner` is zero and the
            // corresponding loop (and its `unravel_into` call) never runs.
            let outer: usize = shape[..ax].iter().product();
            let inner: usize = shape[ax + 1..].iter().product();
            let mut result_data = Vec::with_capacity(outer * kept_positions.len() * inner);
            let mut coords = vec![0usize; shape.len()];
            for outer_idx in 0..outer {
                unravel_into(outer_idx, &shape[..ax], &mut coords[..ax]);
                for &pos in &kept_positions {
                    coords[ax] = pos;
                    for inner_idx in 0..inner {
                        unravel_into(inner_idx, &shape[ax + 1..], &mut coords[ax + 1..]);
                        result_data.push(array.get(&coords)?);
                    }
                }
            }
            Ok(Array::from_vec(result_data).reshape(&new_shape))
        }
        None => {
            let flat = array.to_vec();
            let flat_size = flat.len();

            let mut keep = vec![true; flat_size];
            for &idx in indices {
                if idx >= flat_size {
                    return Err(NumRs2Error::InvalidOperation(format!(
                        "Index {} out of bounds for flattened array with size {}",
                        idx, flat_size
                    )));
                }
                keep[idx] = false;
            }

            let result_data: Vec<T> = flat
                .into_iter()
                .zip(keep)
                .filter(|(_, keep_it)| *keep_it)
                .map(|(value, _)| value)
                .collect();
            Ok(Array::from_vec(result_data))
        }
    }
}
/// Returns a copy of `array` with `values` inserted before the positions given
/// by `indices`.
///
/// Indices refer to positions in the *original* array. Insertions are applied
/// in ascending index order; insertions that share an index keep the relative
/// order in which they appear in `indices`.
///
/// # Along an axis (`axis = Some(ax)`)
///
/// Each inserted position along `ax` is a whole sub-array filled with a single
/// value. `values` is interpreted according to its length:
///
/// * `1` – the single value is used for every insertion;
/// * `indices.len()` – `values[i]` is used for the insertion at `indices[i]`;
/// * a larger multiple `k * indices.len()` – `k` consecutive sub-arrays are
///   inserted at `indices[i]`, filled with `values[i * k]`, …,
///   `values[i * k + k - 1]`.
///
/// An index equal to the axis length appends at the end.
///
/// # Flattened (`axis = None`)
///
/// Operates on `array.to_vec()` and returns the result of `Array::from_vec`.
/// `values` must have length 1 or `indices.len()`. Indices at or beyond the
/// flattened length are appended at the end in ascending index order.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `ax >= array.ndim()`.
/// * `NumRs2Error::InvalidOperation` if an index exceeds the axis length, if
///   `values` is empty while `indices` is not, or if the length of `values`
///   does not fit one of the accepted forms.
/// * Any error returned by `array.get` while reading elements.
pub fn insert<T: Clone + Zero>(
    array: &Array<T>,
    indices: &[usize],
    values: &[T],
    axis: Option<usize>,
) -> Result<Array<T>> {
    // Where one position along the output axis takes its data from.
    enum AxisSlot {
        // Copy the sub-array at this position of the source axis.
        Source(usize),
        // Fill the sub-array with `values[..]` at this index.
        Inserted(usize),
    }

    // Writes the row-major coordinates of `flat` within `dims` into `out`
    // (`dims` and `out` have the same length; every `dim` must be non-zero).
    fn unravel_into(mut flat: usize, dims: &[usize], out: &mut [usize]) {
        for (coord, &dim) in out.iter_mut().zip(dims).rev() {
            *coord = flat % dim;
            flat /= dim;
        }
    }

    // Positions into `indices`, ordered by target index. The sort is stable so
    // that insertions sharing an index keep their caller-supplied order.
    let mut order: Vec<usize> = (0..indices.len()).collect();
    order.sort_by_key(|&i| indices[i]);

    match axis {
        Some(ax) => {
            if ax >= array.ndim() {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    ax,
                    array.ndim()
                )));
            }
            let shape = array.shape();
            let axis_size = shape[ax];

            // `idx == axis_size` means "append"; anything larger has no place
            // to go and would otherwise leave the output buffer short.
            if let Some(&bad) = indices.iter().find(|&&idx| idx > axis_size) {
                return Err(NumRs2Error::InvalidOperation(format!(
                    "Index {} out of bounds for axis {} with size {}",
                    bad, ax, axis_size
                )));
            }
            if values.is_empty() && !indices.is_empty() {
                return Err(NumRs2Error::InvalidOperation(
                    "Values array cannot be empty".into(),
                ));
            }

            let values_per_insertion = if values.len() == 1 || values.len() == indices.len() {
                1
            } else if values.len().is_multiple_of(indices.len()) {
                values.len() / indices.len()
            } else {
                return Err(NumRs2Error::InvalidOperation(
                    "Values length must be 1, equal to indices length, or a multiple of indices length".into()
                ));
            };

            // Lay out the new axis once: for every output position, either a
            // source position to copy or the value to fill with.
            let new_axis_len = axis_size + indices.len() * values_per_insertion;
            let mut plan = Vec::with_capacity(new_axis_len);
            let mut next = 0;
            for src_pos in 0..=axis_size {
                while next < order.len() && indices[order[next]] == src_pos {
                    let which = order[next];
                    for k in 0..values_per_insertion {
                        let value_idx = if values.len() == 1 {
                            0
                        } else {
                            which * values_per_insertion + k
                        };
                        plan.push(AxisSlot::Inserted(value_idx));
                    }
                    next += 1;
                }
                if src_pos < axis_size {
                    plan.push(AxisSlot::Source(src_pos));
                }
            }

            let mut new_shape = shape.clone();
            new_shape[ax] = new_axis_len;

            // Emit the output in row-major order of `new_shape` so that the
            // final `reshape` sees the data in the layout it expects for any
            // axis, not just axis 0.
            let outer: usize = shape[..ax].iter().product();
            let inner: usize = shape[ax + 1..].iter().product();
            let mut result_data = Vec::with_capacity(outer * new_axis_len * inner);
            let mut coords = vec![0usize; shape.len()];
            for outer_idx in 0..outer {
                unravel_into(outer_idx, &shape[..ax], &mut coords[..ax]);
                for slot in &plan {
                    match *slot {
                        AxisSlot::Inserted(value_idx) => {
                            for _ in 0..inner {
                                result_data.push(values[value_idx].clone());
                            }
                        }
                        AxisSlot::Source(src_pos) => {
                            coords[ax] = src_pos;
                            for inner_idx in 0..inner {
                                unravel_into(inner_idx, &shape[ax + 1..], &mut coords[ax + 1..]);
                                result_data.push(array.get(&coords)?);
                            }
                        }
                    }
                }
            }
            Ok(Array::from_vec(result_data).reshape(&new_shape))
        }
        None => {
            if indices.len() != values.len() && values.len() != 1 {
                return Err(NumRs2Error::InvalidOperation(
                    "For flat insertion, values must have length 1 or match indices length".into(),
                ));
            }
            let flat = array.to_vec();

            let mut result_data = Vec::with_capacity(flat.len() + indices.len());
            let mut next = 0;
            for (pos, item) in flat.into_iter().enumerate() {
                while next < order.len() && indices[order[next]] == pos {
                    let which = order[next];
                    let value_idx = if values.len() == 1 { 0 } else { which };
                    result_data.push(values[value_idx].clone());
                    next += 1;
                }
                result_data.push(item);
            }
            // Whatever is left targets the end of the array or beyond; append
            // it in ascending index order.
            for &which in &order[next..] {
                let value_idx = if values.len() == 1 { 0 } else { which };
                result_data.push(values[value_idx].clone());
            }
            Ok(Array::from_vec(result_data))
        }
    }
}
/// Returns a copy of a 1-D `array` with leading and/or trailing zeros removed.
///
/// `trim` selects the sides to trim: it trims the front if it contains `'f'`
/// and the back if it contains `'b'`. `None` is treated as `"fb"`. A string
/// containing neither character leaves the data untouched.
///
/// If every element is zero and at least one side is trimmed, the result is an
/// empty array. An empty input is returned as a clone of `array`.
///
/// # Errors
///
/// `NumRs2Error::InvalidOperation` if `array` is not 1-D.
pub fn trim_zeros<T>(array: &Array<T>, trim: Option<&str>) -> Result<Array<T>>
where
    T: Clone + Zero + PartialEq,
{
    if array.ndim() != 1 {
        return Err(NumRs2Error::InvalidOperation(
            "trim_zeros requires a 1-D array".into(),
        ));
    }
    let data = array.to_vec();
    if data.is_empty() {
        return Ok(array.clone());
    }

    let trim_str = trim.unwrap_or("fb");
    let trim_front = trim_str.contains('f');
    let trim_back = trim_str.contains('b');

    // No non-zero element at all: trimming from either side consumes the whole
    // array, regardless of which side was requested.
    let Some(first_nonzero) = data.iter().position(|x| !x.is_zero()) else {
        let remaining = if trim_front || trim_back {
            Vec::new()
        } else {
            data
        };
        return Ok(Array::from_vec(remaining));
    };
    // A non-zero element exists, so the reverse search always succeeds; the
    // fallback only keeps this expression panic-free.
    let last_nonzero = data
        .iter()
        .rposition(|x| !x.is_zero())
        .unwrap_or(first_nonzero);

    let start = if trim_front { first_nonzero } else { 0 };
    let end = if trim_back {
        last_nonzero + 1
    } else {
        data.len()
    };
    Ok(Array::from_vec(data[start..end].to_vec()))
}
/// Collects the elements of `array` whose matching entry in `condition` is
/// `true`.
///
/// Both inputs are read through `to_vec()` and paired element by element; the
/// selected elements are returned via `Array::from_vec`.
///
/// # Errors
///
/// `NumRs2Error::DimensionMismatch` if `array` and `condition` do not have the
/// same shape.
pub fn extract<T: Clone>(array: &Array<T>, condition: &Array<bool>) -> Result<Array<T>> {
    let (array_shape, condition_shape) = (array.shape(), condition.shape());
    if array_shape != condition_shape {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Array and condition must have the same shape, got {:?} and {:?}",
            array_shape, condition_shape
        )));
    }

    let selected: Vec<T> = array
        .to_vec()
        .into_iter()
        .zip(condition.to_vec())
        .filter(|(_, keep)| *keep)
        .map(|(value, _)| value)
        .collect();
    Ok(Array::from_vec(selected))
}
/// Overwrites, in place, the elements of `array` whose matching `mask` entry is
/// `true`, taking replacements from `values` in order.
///
/// The mask is read through `to_vec()` and applied position by position to the
/// array's mutable slice.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `array` and `mask` differ in shape.
/// * `NumRs2Error::InvalidOperation` if `values.len()` differs from the number
///   of `true` entries in `mask`, or if the array's storage cannot be borrowed
///   as a mutable slice.
pub fn place<T: Clone>(array: &mut Array<T>, mask: &Array<bool>, values: &[T]) -> Result<()> {
    let (array_shape, mask_shape) = (array.shape(), mask.shape());
    if array_shape != mask_shape {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Array and mask must have the same shape, got {:?} and {:?}",
            array_shape, mask_shape
        )));
    }

    let mask_flat = mask.to_vec();
    let true_count = mask_flat.iter().filter(|&&flag| flag).count();
    if values.len() != true_count {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Number of values ({}) must match number of true elements in mask ({})",
            values.len(),
            true_count
        )));
    }

    let array_slice = match array.array_mut().as_slice_mut() {
        Some(slice) => slice,
        None => {
            return Err(NumRs2Error::InvalidOperation(
                "Failed to get mutable slice".into(),
            ))
        }
    };

    // Pair every masked slot with the next replacement value; the count check
    // above guarantees both sequences have the same length.
    let masked_slots = array_slice
        .iter_mut()
        .zip(mask_flat.iter())
        .filter(|(_, on)| **on)
        .map(|(slot, _)| slot);
    for (slot, replacement) in masked_slots.zip(values.iter()) {
        *slot = replacement.clone();
    }
    Ok(())
}
/// Writes `values` into `array` in place at the given positions of its mutable
/// slice.
///
/// The `i`-th index receives `values[i % values.len()]`, so a shorter `values`
/// is recycled. All indices are validated before anything is written.
///
/// # Errors
///
/// `NumRs2Error::InvalidOperation` if `values` is empty, if the array's storage
/// cannot be borrowed as a mutable slice, or if any index is out of bounds.
pub fn put<T: Clone>(array: &mut Array<T>, indices: &[usize], values: &[T]) -> Result<()> {
    if values.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Values array cannot be empty".into(),
        ));
    }

    let array_slice = match array.array_mut().as_slice_mut() {
        Some(slice) => slice,
        None => {
            return Err(NumRs2Error::InvalidOperation(
                "Failed to get mutable slice".into(),
            ))
        }
    };
    let array_size = array_slice.len();

    // Reject the first offending index up front so a failure never leaves the
    // array partially updated.
    if let Some(&bad) = indices.iter().find(|&&idx| idx >= array_size) {
        return Err(NumRs2Error::InvalidOperation(format!(
            "Index {} out of bounds for array with size {}",
            bad, array_size
        )));
    }

    for (&target, value) in indices.iter().zip(values.iter().cycle()) {
        array_slice[target] = value.clone();
    }
    Ok(())
}
/// Selects slices of `array` according to a 1-D boolean `condition`.
///
/// * `axis = Some(ax)` keeps the positions along axis `ax` whose `condition`
///   entry is `true`; the result has the input shape with that axis shortened
///   to the number of `true` entries. `condition` must have exactly as many
///   entries as the axis is long.
/// * `axis = None` works on the flattened array (`array.to_vec()`), keeping the
///   elements whose `condition` entry is `true`, and returns them via
///   `Array::from_vec`. `condition` must have exactly as many entries as the
///   array has elements.
///
/// # Errors
///
/// * `NumRs2Error::InvalidOperation` if `condition` is not 1-D.
/// * `NumRs2Error::DimensionMismatch` if `ax >= array.ndim()` or if the length
///   of `condition` does not match the axis length (or the flattened size).
/// * Any error returned by `array.get` while reading elements.
pub fn compress<T: Clone + Zero>(
    array: &Array<T>,
    condition: &Array<bool>,
    axis: Option<usize>,
) -> Result<Array<T>> {
    // Writes the row-major coordinates of `flat` within `dims` into `out`
    // (`dims` and `out` have the same length; every `dim` must be non-zero).
    fn unravel_into(mut flat: usize, dims: &[usize], out: &mut [usize]) {
        for (coord, &dim) in out.iter_mut().zip(dims).rev() {
            *coord = flat % dim;
            flat /= dim;
        }
    }

    if condition.ndim() != 1 {
        return Err(NumRs2Error::InvalidOperation(
            "Condition must be a 1-D array".into(),
        ));
    }
    match axis {
        Some(ax) => {
            if ax >= array.ndim() {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Axis {} out of bounds for array of dimension {}",
                    ax,
                    array.ndim()
                )));
            }
            let shape = array.shape();
            let axis_size = shape[ax];
            if condition.size() != axis_size {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Condition length {} doesn't match axis {} size {}",
                    condition.size(),
                    ax,
                    axis_size
                )));
            }

            let selected_indices: Vec<usize> = condition
                .to_vec()
                .iter()
                .enumerate()
                .filter_map(|(i, &val)| if val { Some(i) } else { None })
                .collect();

            let mut new_shape = shape.clone();
            new_shape[ax] = selected_indices.len();

            // Walk the output in row-major order: everything before the axis,
            // then the selected axis positions, then everything after it. This
            // reads source coordinates directly instead of decomposing output
            // indices with strides, so it is correct for every axis. An empty
            // selection simply produces an empty buffer.
            let outer: usize = shape[..ax].iter().product();
            let inner: usize = shape[ax + 1..].iter().product();
            let mut result_data = Vec::with_capacity(outer * selected_indices.len() * inner);
            let mut coords = vec![0usize; shape.len()];
            for outer_idx in 0..outer {
                unravel_into(outer_idx, &shape[..ax], &mut coords[..ax]);
                for &src_pos in &selected_indices {
                    coords[ax] = src_pos;
                    for inner_idx in 0..inner {
                        unravel_into(inner_idx, &shape[ax + 1..], &mut coords[ax + 1..]);
                        result_data.push(array.get(&coords)?);
                    }
                }
            }
            Ok(Array::from_vec(result_data).reshape(&new_shape))
        }
        None => {
            // With no axis the array is treated as flattened, so arrays of any
            // dimensionality are accepted as long as the element count matches.
            let flat = array.to_vec();
            if condition.size() != flat.len() {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Condition length {} doesn't match flattened array size {}",
                    condition.size(),
                    flat.len()
                )));
            }
            let selected: Vec<T> = flat
                .into_iter()
                .zip(condition.to_vec())
                .filter(|(_, keep)| *keep)
                .map(|(value, _)| value)
                .collect();
            Ok(Array::from_vec(selected))
        }
    }
}