use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Zero;
/// Select elements of `array` where `condition` is `true`, NumPy `compress` style.
///
/// * `axis == None`: both arrays are flattened and must contain the same number
///   of elements. The result is a 1-D array holding the selected values in
///   flattened order.
/// * `axis == Some(ax)`: the flattened `condition` must have exactly
///   `array.shape()[ax]` entries. Every slice along `ax` whose flag is `true` is
///   kept. The result has the shape of `array`, except that dimension `ax`
///   becomes the number of `true` flags. When no flag is `true`, a zero-filled
///   array of that (empty) shape is returned.
///
/// # Errors
///
/// * [`NumRs2Error::ShapeMismatch`] if the condition length does not match.
/// * [`NumRs2Error::DimensionMismatch`] if `ax` is out of bounds.
/// * Any error produced by `Array::get` while gathering elements.
pub fn compress<T: Clone + Zero>(
    array: &Array<T>,
    condition: &Array<bool>,
    axis: Option<usize>,
) -> Result<Array<T>> {
    match axis {
        None => {
            let flat = array.to_vec();
            let cond_flat = condition.to_vec();
            if flat.len() != cond_flat.len() {
                return Err(NumRs2Error::ShapeMismatch {
                    expected: vec![flat.len()],
                    actual: vec![cond_flat.len()],
                });
            }
            let compressed: Vec<T> = flat
                .into_iter()
                .zip(cond_flat)
                .filter_map(|(val, keep)| if keep { Some(val) } else { None })
                .collect();
            Ok(Array::from_vec(compressed))
        }
        Some(ax) => {
            let shape = array.shape();
            if ax >= shape.len() {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    ax,
                    shape.len()
                )));
            }
            let keep = condition.to_vec();
            if keep.len() != shape[ax] {
                return Err(NumRs2Error::ShapeMismatch {
                    expected: vec![shape[ax]],
                    actual: vec![keep.len()],
                });
            }
            let kept = keep.iter().filter(|&&flag| flag).count();
            let mut new_shape = shape.clone();
            new_shape[ax] = kept;
            if kept == 0 {
                return Ok(Array::zeros(&new_shape));
            }
            let total_elements: usize = shape.iter().product();
            let mut result_data = Vec::with_capacity(new_shape.iter().product());
            let mut current_indices = vec![0; shape.len()];
            for _ in 0..total_elements {
                // `keep` has exactly shape[ax] entries, so this is an O(1) lookup
                // rather than a linear search through the selected indices.
                if keep[current_indices[ax]] {
                    result_data.push(array.get(&current_indices)?);
                }
                // Advance the multi-index in row-major order (last axis fastest).
                for dim in (0..shape.len()).rev() {
                    current_indices[dim] += 1;
                    if current_indices[dim] < shape[dim] {
                        break;
                    }
                    current_indices[dim] = 0;
                }
            }
            Ok(Array::from_vec(result_data).reshape(&new_shape))
        }
    }
}
/// Return the elements of `array` whose matching `condition` entry is `true`.
///
/// `array` and `condition` must have exactly the same shape. Both are read in
/// flattened order, and the surviving values are returned as a 1-D array in
/// that same order. An all-`false` condition yields an empty array.
///
/// # Errors
///
/// [`NumRs2Error::ShapeMismatch`] when the two shapes differ.
pub fn extract<T: Clone>(array: &Array<T>, condition: &Array<bool>) -> Result<Array<T>> {
    let array_shape = array.shape();
    let condition_shape = condition.shape();
    if array_shape != condition_shape {
        return Err(NumRs2Error::ShapeMismatch {
            expected: array_shape,
            actual: condition_shape,
        });
    }
    let values = array.to_vec();
    let keep_flags = condition.to_vec();
    let mut kept = Vec::new();
    for (value, keep) in values.into_iter().zip(keep_flags) {
        if keep {
            kept.push(value);
        }
    }
    Ok(Array::from_vec(kept))
}
/// Overwrite the elements of `array` where `mask` is `true` with values taken
/// sequentially from `values`, NumPy `place` style.
///
/// The `k`-th `true` entry of `mask` (in flattened order) receives
/// `values[k % values.len()]`. This means `values` is cycled when it is shorter
/// than the number of `true` entries.
///
/// If `mask` has no `true` entry, the array is left untouched and `Ok(())` is
/// returned, even when `values` is empty. This matches `numpy.place`.
///
/// # Errors
///
/// * [`NumRs2Error::ShapeMismatch`] if `array` and `mask` have different shapes.
/// * [`NumRs2Error::ValueError`] if `values` is empty but at least one element
///   would have to be replaced.
/// * [`NumRs2Error::InvalidOperation`] if the array's storage cannot be borrowed
///   as one mutable slice.
pub fn place<T: Clone>(array: &mut Array<T>, mask: &Array<bool>, values: &[T]) -> Result<()> {
    if array.shape() != mask.shape() {
        return Err(NumRs2Error::ShapeMismatch {
            expected: array.shape(),
            actual: mask.shape(),
        });
    }
    let mask_data = mask.to_vec();
    // Nothing to replace. Succeed before validating `values`, so that an empty
    // `values` is accepted for an all-false mask.
    if !mask_data.iter().any(|&selected| selected) {
        return Ok(());
    }
    if values.is_empty() {
        return Err(NumRs2Error::ValueError(
            "values array cannot be empty".to_string(),
        ));
    }
    let array_data = array
        .array_mut()
        .as_slice_mut()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to get mutable slice".into()))?;
    // Only the selected slots consume a value, so pair them with a cycling
    // iterator over `values`.
    let targets = array_data
        .iter_mut()
        .zip(&mask_data)
        .filter_map(|(slot, &selected)| if selected { Some(slot) } else { None });
    for (slot, value) in targets.zip(values.iter().cycle()) {
        *slot = value.clone();
    }
    Ok(())
}
/// Write `values` into `array` at the given flat `indices`, NumPy `put` style.
///
/// The `i`-th index receives `values[i % values.len()]`, so `values` repeats
/// when it is shorter than `indices`. When an index appears more than once,
/// the last write to it wins. Every index is bounds-checked against
/// `array.size()` before anything is modified.
///
/// # Errors
///
/// * [`NumRs2Error::ValueError`] if `values` is empty.
/// * [`NumRs2Error::IndexOutOfBounds`] for the first index `>= array.size()`.
/// * [`NumRs2Error::InvalidOperation`] if the array's storage cannot be borrowed
///   as one mutable slice.
pub fn put<T: Clone>(array: &mut Array<T>, indices: &Array<usize>, values: &[T]) -> Result<()> {
    if values.is_empty() {
        return Err(NumRs2Error::ValueError(
            "values array cannot be empty".to_string(),
        ));
    }
    let targets = indices.to_vec();
    let size = array.size();
    if let Some(&bad) = targets.iter().find(|&&target| target >= size) {
        return Err(NumRs2Error::IndexOutOfBounds(format!(
            "index {} is out of bounds for array of size {}",
            bad, size
        )));
    }
    let slots = array
        .array_mut()
        .as_slice_mut()
        .ok_or_else(|| NumRs2Error::InvalidOperation("Failed to get mutable slice".into()))?;
    for (&target, value) in targets.iter().zip(values.iter().cycle()) {
        slots[target] = value.clone();
    }
    Ok(())
}
pub fn putmask<T: Clone>(
array: &mut Array<T>,
mask: &Array<bool>,
values: &Array<T>,
) -> Result<()> {
place(array, mask, &values.to_vec())
}
/// Gather values from `array` along `axis` using per-position `indices`,
/// NumPy `take_along_axis` style.
///
/// `indices` must have the same number of dimensions as `array`. It must also
/// have the same length as `array` in every dimension except `axis`. The output
/// has the shape of `indices`. Output element `p` is `array[q]`, where `q`
/// equals `p` except that `q[axis] = indices[p]`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `axis` is out of bounds or the ranks differ.
/// * [`NumRs2Error::ShapeMismatch`] if a non-`axis` dimension differs.
/// * [`NumRs2Error::IndexOutOfBounds`] if an index is `>= array.shape()[axis]`.
/// * Any error produced by `Array::get`.
pub fn take_along_axis<T: Clone + Zero>(
    array: &Array<T>,
    indices: &Array<usize>,
    axis: usize,
) -> Result<Array<T>> {
    let arr_shape = array.shape();
    let ind_shape = indices.shape();
    if axis >= arr_shape.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "axis {} is out of bounds for array of dimension {}",
            axis,
            arr_shape.len()
        )));
    }
    if arr_shape.len() != ind_shape.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "array and indices must have same number of dimensions, got {} and {}",
            arr_shape.len(),
            ind_shape.len()
        )));
    }
    for (i, (&arr_dim, &ind_dim)) in arr_shape.iter().zip(ind_shape.iter()).enumerate() {
        if i != axis && arr_dim != ind_dim {
            return Err(NumRs2Error::ShapeMismatch {
                expected: arr_shape.clone(),
                actual: ind_shape.clone(),
            });
        }
    }
    let axis_len = arr_shape[axis];
    let total_elements = indices.size();
    let mut result_data = Vec::with_capacity(total_elements);
    let mut current_pos = vec![0; ind_shape.len()];
    // Reused for every element instead of cloning `current_pos` each time.
    let mut source_pos = vec![0; ind_shape.len()];
    for _ in 0..total_elements {
        let idx = indices.get(&current_pos)?;
        if idx >= axis_len {
            return Err(NumRs2Error::IndexOutOfBounds(format!(
                "index {} is out of bounds for axis {} with size {}",
                idx, axis, axis_len
            )));
        }
        source_pos.copy_from_slice(&current_pos);
        source_pos[axis] = idx;
        result_data.push(array.get(&source_pos)?);
        // Advance the multi-index over `ind_shape` in row-major order.
        for dim in (0..ind_shape.len()).rev() {
            current_pos[dim] += 1;
            if current_pos[dim] < ind_shape[dim] {
                break;
            }
            current_pos[dim] = 0;
        }
    }
    Ok(Array::from_vec(result_data).reshape(&ind_shape))
}
/// Call `func` on every 1-D lane of `array` taken along `axis` and collect the
/// scalar results.
///
/// Each lane is copied into a fresh 1-D `Array` before it is passed to `func`.
/// Lanes are visited in row-major order over the remaining dimensions, and the
/// output has the shape of `array` with `axis` removed. For a 1-D input there
/// is no remaining dimension: `func` is called once on the whole array and the
/// result is wrapped in a one-element array.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `axis` is out of bounds.
/// * Any error produced by `Array::get` while copying a lane.
pub fn apply_along_axis<T, U, F>(func: F, array: &Array<T>, axis: usize) -> Result<Array<U>>
where
    T: Clone + Zero,
    U: Clone + Zero,
    F: Fn(&Array<T>) -> U,
{
    let shape = array.shape();
    let ndim = shape.len();
    if axis >= ndim {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "axis {} is out of bounds for array of dimension {}",
            axis, ndim
        )));
    }
    // Output shape: the input shape with `axis` dropped.
    let out_shape: Vec<usize> = shape
        .iter()
        .enumerate()
        .filter(|&(dim, _)| dim != axis)
        .map(|(_, &len)| len)
        .collect();
    if out_shape.is_empty() {
        return Ok(Array::from_vec(vec![func(array)]));
    }
    let lane_len = shape[axis];
    let lane_count: usize = out_shape.iter().product();
    let mut results = Vec::new();
    let mut coords = vec![0; ndim];
    for lane in 0..lane_count {
        // Decode `lane` (row-major over the non-axis dims) into coordinates.
        let mut rem = lane;
        for dim in (0..ndim).rev().filter(|&dim| dim != axis) {
            coords[dim] = rem % shape[dim];
            rem /= shape[dim];
        }
        let mut lane_data = Vec::with_capacity(lane_len);
        for step in 0..lane_len {
            coords[axis] = step;
            lane_data.push(array.get(&coords)?);
        }
        results.push(func(&Array::from_vec(lane_data)));
    }
    Ok(Array::from_vec(results).reshape(&out_shape))
}
/// Repeatedly apply a dimension-reducing `func`, once per entry of `axes`.
///
/// `axes` is processed from the largest index to the smallest. Removing a
/// higher dimension therefore never shifts the position of a lower one that
/// is still to be processed. Before each call, the axis is checked against the
/// current number of dimensions. After each call, the returned array must have
/// exactly one dimension fewer than its input.
///
/// NOTE(review): the axis value is never passed to `func`. It only drives
/// validation and the number of calls, so `func` itself must know which
/// dimension to drop. Confirm this is the intended contract.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if an axis is out of range at the point
///   it is processed.
/// * [`NumRs2Error::InvalidOperation`] if `func` does not reduce `ndim` by one.
/// * Any error returned by `func`.
pub fn apply_over_axes<T, F>(func: F, array: &Array<T>, axes: &[usize]) -> Result<Array<T>>
where
    T: Clone + Zero,
    F: Fn(&Array<T>) -> Result<Array<T>>,
{
    let mut descending = axes.to_vec();
    descending.sort_unstable_by(|a, b| b.cmp(a));
    descending.into_iter().try_fold(
        array.clone(),
        |current: Array<T>, axis: usize| -> Result<Array<T>> {
            let ndim = current.ndim();
            if axis >= ndim {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "axis {} is out of bounds for array of dimension {}",
                    axis, ndim
                )));
            }
            let reduced = func(&current)?;
            if reduced.ndim() != ndim - 1 {
                return Err(NumRs2Error::InvalidOperation(
                    "Function must reduce dimension by 1".to_string(),
                ));
            }
            Ok(reduced)
        },
    )
}
/// Take elements from `array` by index, NumPy `take` style.
///
/// * `axis == None`: `array` is flattened and each entry of `indices` selects one
///   flat element. The result has the shape of `indices`.
/// * `axis == Some(ax)`: `indices` is flattened and each entry selects a whole
///   slice along `ax`. The result has the shape of `array`, with dimension `ax`
///   replaced by the number of indices. Repeated indices duplicate slices.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `ax` is out of bounds.
/// * [`NumRs2Error::IndexOutOfBounds`] for the first out-of-range index.
/// * Any error produced by `Array::get`.
pub fn take<T: Clone + Zero>(
    array: &Array<T>,
    indices: &Array<usize>,
    axis: Option<usize>,
) -> Result<Array<T>> {
    if let Some(ax) = axis {
        let shape = array.shape();
        let ndim = shape.len();
        if ax >= ndim {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "axis {} is out of bounds for array of dimension {}",
                ax, ndim
            )));
        }
        let axis_len = shape[ax];
        let picks = indices.to_vec();
        if let Some(&bad) = picks.iter().find(|&&pick| pick >= axis_len) {
            return Err(NumRs2Error::IndexOutOfBounds(format!(
                "index {} is out of bounds for axis {} with size {}",
                bad, ax, axis_len
            )));
        }
        let mut out_shape = shape;
        out_shape[ax] = picks.len();
        let out_len: usize = out_shape.iter().product();
        let mut gathered = Vec::with_capacity(out_len);
        let mut out_pos = vec![0; ndim];
        let mut src_pos = vec![0; ndim];
        for _ in 0..out_len {
            // The source coordinate matches the output one, except along `ax`,
            // where it is redirected through `picks`.
            src_pos.copy_from_slice(&out_pos);
            src_pos[ax] = picks[out_pos[ax]];
            gathered.push(array.get(&src_pos)?);
            // Row-major odometer step over `out_shape`.
            for dim in (0..ndim).rev() {
                out_pos[dim] += 1;
                if out_pos[dim] < out_shape[dim] {
                    break;
                }
                out_pos[dim] = 0;
            }
        }
        Ok(Array::from_vec(gathered).reshape(&out_shape))
    } else {
        let flat = array.to_vec();
        let picks = indices.to_vec();
        let size = flat.len();
        if let Some(&bad) = picks.iter().find(|&&pick| pick >= size) {
            return Err(NumRs2Error::IndexOutOfBounds(format!(
                "index {} is out of bounds for flattened array of size {}",
                bad, size
            )));
        }
        let gathered: Vec<T> = picks.iter().map(|&pick| flat[pick].clone()).collect();
        Ok(Array::from_vec(gathered).reshape(&indices.shape()))
    }
}
/// Integer-array ("fancy") indexing with one index array per dimension.
///
/// `indices[d]` supplies the coordinate along dimension `d`, and all index arrays
/// must share one shape. Result element `k` is
/// `array[indices[0][k], indices[1][k], ...]`, and the result takes the shape of
/// the index arrays.
///
/// # Errors
///
/// * [`NumRs2Error::ValueError`] if `indices` is empty.
/// * [`NumRs2Error::DimensionMismatch`] if there is not one index array per dimension.
/// * [`NumRs2Error::ShapeMismatch`] if the index arrays differ in shape.
/// * [`NumRs2Error::IndexOutOfBounds`] for the first out-of-range coordinate.
/// * Any error produced by `Array::get`.
pub fn fancy_index<T: Clone + Zero>(
    array: &Array<T>,
    indices: &[Array<usize>],
) -> Result<Array<T>> {
    let shape = array.shape();
    if indices.is_empty() {
        return Err(NumRs2Error::ValueError(
            "indices array cannot be empty".to_string(),
        ));
    }
    if indices.len() != shape.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "number of index arrays ({}) must match array dimensions ({})",
            indices.len(),
            shape.len()
        )));
    }
    let idx_shape = indices[0].shape();
    if let Some(actual) = indices[1..]
        .iter()
        .map(|idx_arr| idx_arr.shape())
        .find(|other| *other != idx_shape)
    {
        return Err(NumRs2Error::ShapeMismatch {
            expected: idx_shape.clone(),
            actual,
        });
    }
    let count = indices[0].size();
    let columns: Vec<Vec<usize>> = indices.iter().map(|idx_arr| idx_arr.to_vec()).collect();
    let mut picked = Vec::with_capacity(count);
    let mut coord = vec![0; shape.len()];
    for k in 0..count {
        for (dim, column) in columns.iter().enumerate() {
            let idx = column[k];
            if idx >= shape[dim] {
                return Err(NumRs2Error::IndexOutOfBounds(format!(
                    "index {} is out of bounds for dimension {} with size {}",
                    idx, dim, shape[dim]
                )));
            }
            coord[dim] = idx;
        }
        picked.push(array.get(&coord)?);
    }
    Ok(Array::from_vec(picked).reshape(&idx_shape))
}
/// Boolean-mask indexing: gather the elements of `array` where `mask` is `true`.
///
/// This is a thin alias for [`extract`]. The mask must have the same shape as
/// `array`, the result is always 1-D in flattened order, and a shape mismatch
/// is reported as an error.
pub fn boolean_index<T: Clone>(array: &Array<T>, mask: &Array<bool>) -> Result<Array<T>> {
    let selected = extract(array, mask)?;
    Ok(selected)
}
/// Build an array by choosing, element by element, from `choices` according to
/// the first `true` entry in `conditions`, NumPy `select` style.
///
/// For each flat position `i`, the first `k` with `conditions[k][i] == true`
/// yields `choices[k][i]`. If no condition holds at `i`, `default` is used.
/// Every condition and choice must have the shape of `conditions[0]`, which is
/// also the shape of the result.
///
/// # Errors
///
/// * [`NumRs2Error::ValueError`] if `conditions` is empty or its length differs
///   from `choices`.
/// * [`NumRs2Error::ShapeMismatch`] for the first condition or choice whose shape
///   differs from `conditions[0]`.
pub fn select<T: Clone>(
    conditions: &[Array<bool>],
    choices: &[Array<T>],
    default: T,
) -> Result<Array<T>> {
    if conditions.is_empty() {
        return Err(NumRs2Error::ValueError(
            "conditions array cannot be empty".to_string(),
        ));
    }
    if conditions.len() != choices.len() {
        return Err(NumRs2Error::ValueError(format!(
            "number of conditions ({}) must match number of choices ({})",
            conditions.len(),
            choices.len()
        )));
    }
    let shape = conditions[0].shape();
    if let Some(actual) = conditions[1..]
        .iter()
        .map(|cond| cond.shape())
        .find(|other| *other != shape)
    {
        return Err(NumRs2Error::ShapeMismatch {
            expected: shape.clone(),
            actual,
        });
    }
    if let Some(actual) = choices
        .iter()
        .map(|choice| choice.shape())
        .find(|other| *other != shape)
    {
        return Err(NumRs2Error::ShapeMismatch {
            expected: shape.clone(),
            actual,
        });
    }
    let masks: Vec<Vec<bool>> = conditions.iter().map(|cond| cond.to_vec()).collect();
    let options: Vec<Vec<T>> = choices.iter().map(|choice| choice.to_vec()).collect();
    let result_data: Vec<T> = (0..conditions[0].size())
        .map(|i| {
            masks
                .iter()
                .zip(options.iter())
                .find(|&(mask, _)| mask[i])
                .map_or_else(|| default.clone(), |(_, option)| option[i].clone())
        })
        .collect();
    Ok(Array::from_vec(result_data).reshape(&shape))
}
#[cfg(test)]
mod tests {
    // Unit tests for the indexing helpers above. Each test builds small arrays
    // with `Array::from_vec` (plus `reshape` for N-D cases) and checks the
    // resulting shape and the flattened contents returned by `to_vec`.
    use super::*;

    #[test]
    fn test_take_1d() {
        let arr = Array::from_vec(vec![10, 20, 30, 40, 50]);
        let indices = Array::from_vec(vec![0, 2, 4, 1]);
        let result = take(&arr, &indices, None).expect("operation should succeed");
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result.to_vec(), vec![10, 30, 50, 20]);
    }

    #[test]
    fn test_take_2d_no_axis() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let indices = Array::from_vec(vec![0, 3, 5]);
        let result = take(&arr, &indices, None).expect("operation should succeed");
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.to_vec(), vec![1, 4, 6]);
    }

    // Exercises `take` with an explicit axis. The function named
    // `take_along_axis` is covered separately below.
    #[test]
    fn test_take_with_axis() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let indices = Array::from_vec(vec![2, 0, 1]);
        let result = take(&arr, &indices, Some(1)).expect("operation should succeed");
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result.to_vec(), vec![3, 1, 2, 6, 4, 5]);
    }

    #[test]
    fn test_take_out_of_bounds() {
        let arr = Array::from_vec(vec![1, 2, 3]);
        let indices = Array::from_vec(vec![0, 5]);
        let result = take(&arr, &indices, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_take_along_axis() {
        // Each row picks its own columns: row 0 -> [0, 2], row 1 -> [1, 0].
        let arr = Array::from_vec(vec![10, 30, 20, 60, 40, 50]).reshape(&[2, 3]);
        let indices = Array::from_vec(vec![0, 2, 1, 0]).reshape(&[2, 2]);
        let result = take_along_axis(&arr, &indices, 1).expect("operation should succeed");
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![10, 20, 40, 60]);
    }

    #[test]
    fn test_take_along_axis_out_of_bounds() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let indices = Array::from_vec(vec![3, 0]).reshape(&[2, 1]);
        assert!(take_along_axis(&arr, &indices, 1).is_err());
    }

    #[test]
    fn test_compress_flat() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let cond = Array::from_vec(vec![false, true, true, false, false, true]);
        let result = compress(&arr, &cond, None).expect("operation should succeed");
        assert_eq!(result.to_vec(), vec![2, 3, 6]);
    }

    #[test]
    fn test_compress_along_axis() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);

        let cols = Array::from_vec(vec![true, false, true]);
        let result = compress(&arr, &cols, Some(1)).expect("operation should succeed");
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![1, 3, 4, 6]);

        let rows = Array::from_vec(vec![false, true]);
        let result = compress(&arr, &rows, Some(0)).expect("operation should succeed");
        assert_eq!(result.shape(), &[1, 3]);
        assert_eq!(result.to_vec(), vec![4, 5, 6]);
    }

    #[test]
    fn test_compress_axis_length_mismatch() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let cond = Array::from_vec(vec![true, false]);
        assert!(compress(&arr, &cond, Some(1)).is_err());
        assert!(compress(&arr, &cond, Some(2)).is_err());
    }

    #[test]
    fn test_extract_shape_mismatch() {
        let arr = Array::from_vec(vec![1, 2, 3]);
        let mask = Array::from_vec(vec![true, false]);
        assert!(extract(&arr, &mask).is_err());
    }

    #[test]
    fn test_place_cycles_values() {
        let mut arr = Array::from_vec(vec![1, 2, 3, 4, 5]);
        let mask = Array::from_vec(vec![true, false, true, false, true]);
        place(&mut arr, &mask, &[7, 8]).expect("operation should succeed");
        assert_eq!(arr.to_vec(), vec![7, 2, 8, 4, 7]);
    }

    #[test]
    fn test_put_cycles_values() {
        let mut arr = Array::from_vec(vec![1, 2, 3, 4, 5]);
        let indices = Array::from_vec(vec![4, 0, 2]);
        put(&mut arr, &indices, &[10, 20]).expect("operation should succeed");
        assert_eq!(arr.to_vec(), vec![20, 2, 10, 4, 10]);
    }

    #[test]
    fn test_put_out_of_bounds() {
        let mut arr = Array::from_vec(vec![1, 2, 3]);
        let indices = Array::from_vec(vec![0, 3]);
        assert!(put(&mut arr, &indices, &[9]).is_err());
        // Validation happens before any write.
        assert_eq!(arr.to_vec(), vec![1, 2, 3]);
    }

    #[test]
    fn test_apply_along_axis_sum() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let row_sums = apply_along_axis(
            |lane: &Array<i32>| lane.to_vec().iter().sum::<i32>(),
            &arr,
            1,
        )
        .expect("operation should succeed");
        assert_eq!(row_sums.shape(), &[2]);
        assert_eq!(row_sums.to_vec(), vec![6, 15]);

        let col_sums = apply_along_axis(
            |lane: &Array<i32>| lane.to_vec().iter().sum::<i32>(),
            &arr,
            0,
        )
        .expect("operation should succeed");
        assert_eq!(col_sums.shape(), &[3]);
        assert_eq!(col_sums.to_vec(), vec![5, 7, 9]);
    }

    #[test]
    fn test_fancy_index_diagonal() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(&[3, 3]);
        let row_idx = Array::from_vec(vec![0, 1, 2]);
        let col_idx = Array::from_vec(vec![0, 1, 2]);
        let result = fancy_index(&arr, &[row_idx, col_idx]).expect("operation should succeed");
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.to_vec(), vec![1, 5, 9]);
    }

    #[test]
    fn test_fancy_index_arbitrary_coords() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6, 7, 8, 9]).reshape(&[3, 3]);
        let row_idx = Array::from_vec(vec![0, 2, 1]);
        let col_idx = Array::from_vec(vec![2, 0, 1]);
        let result = fancy_index(&arr, &[row_idx, col_idx]).expect("operation should succeed");
        assert_eq!(result.shape(), &[3]);
        assert_eq!(result.to_vec(), vec![3, 7, 5]);
    }

    #[test]
    fn test_fancy_index_2d_indices() {
        let arr = Array::from_vec(vec![10, 20, 30, 40, 50, 60]).reshape(&[2, 3]);
        let row_idx = Array::from_vec(vec![0, 1, 0, 1]).reshape(&[2, 2]);
        let col_idx = Array::from_vec(vec![0, 1, 2, 2]).reshape(&[2, 2]);
        let result = fancy_index(&arr, &[row_idx, col_idx]).expect("operation should succeed");
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![10, 50, 30, 60]);
    }

    #[test]
    fn test_fancy_index_mismatched_shapes() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let row_idx = Array::from_vec(vec![0, 1]);
        let col_idx = Array::from_vec(vec![0, 1, 2]);
        let result = fancy_index(&arr, &[row_idx, col_idx]);
        assert!(result.is_err());
    }

    #[test]
    fn test_fancy_index_out_of_bounds() {
        let arr = Array::from_vec(vec![1, 2, 3, 4]).reshape(&[2, 2]);
        // Row index 3 is outside a 2-row array.
        let row_idx = Array::from_vec(vec![0, 3]);
        let col_idx = Array::from_vec(vec![0, 0]);
        let result = fancy_index(&arr, &[row_idx, col_idx]);
        assert!(result.is_err());
    }

    #[test]
    fn test_boolean_index_simple() {
        let arr = Array::from_vec(vec![10, 20, 30, 40, 50]);
        let mask = Array::from_vec(vec![true, false, true, false, true]);
        let result = boolean_index(&arr, &mask).expect("operation should succeed");
        assert_eq!(result.to_vec(), vec![10, 30, 50]);
    }

    #[test]
    fn test_boolean_index_with_comparison() {
        let arr = Array::from_vec(vec![1, 5, 3, 8, 2]);
        let mask = arr.map(|x| x > 3);
        let result = boolean_index(&arr, &mask).expect("operation should succeed");
        assert_eq!(result.to_vec(), vec![5, 8]);
    }

    #[test]
    fn test_boolean_index_2d() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5, 6]).reshape(&[2, 3]);
        let mask = Array::from_vec(vec![true, false, true, false, true, false]).reshape(&[2, 3]);
        let result = boolean_index(&arr, &mask).expect("operation should succeed");
        assert_eq!(result.to_vec(), vec![1, 3, 5]);
    }

    #[test]
    fn test_boolean_index_all_false() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5]);
        let mask = Array::from_vec(vec![false; 5]);
        let result = boolean_index(&arr, &mask).expect("operation should succeed");
        assert_eq!(result.size(), 0);
    }

    #[test]
    fn test_boolean_index_all_true() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5]);
        let mask = Array::from_vec(vec![true; 5]);
        let result = boolean_index(&arr, &mask).expect("operation should succeed");
        assert_eq!(result.to_vec(), vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_select_simple() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5]);
        let cond1 = arr.map(|x| x < 3);
        let choice1 = arr.map(|x| x * 10);
        let cond2 = arr.map(|x| x >= 3);
        let choice2 = arr.map(|x| x * 100);
        let result =
            select(&[cond1, cond2], &[choice1, choice2], 0).expect("operation should succeed");
        assert_eq!(result.to_vec(), vec![10, 20, 300, 400, 500]);
    }

    #[test]
    fn test_select_with_default() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5]);
        let cond = arr.map(|x| x > 3);
        let choice = arr.map(|x| x * 10);
        let result = select(&[cond], &[choice], -1).expect("operation should succeed");
        assert_eq!(result.to_vec(), vec![-1, -1, -1, 40, 50]);
    }

    #[test]
    fn test_select_multiple_conditions() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5]);
        let cond1 = arr.map(|x| x == 1);
        let choice1 = Array::from_vec(vec![100, 100, 100, 100, 100]);
        let cond2 = arr.map(|x| x == 3);
        let choice2 = Array::from_vec(vec![300, 300, 300, 300, 300]);
        let cond3 = arr.map(|x| x == 5);
        let choice3 = Array::from_vec(vec![500, 500, 500, 500, 500]);
        let result = select(&[cond1, cond2, cond3], &[choice1, choice2, choice3], 0)
            .expect("operation should succeed");
        assert_eq!(result.to_vec(), vec![100, 0, 300, 0, 500]);
    }

    #[test]
    fn test_select_2d() {
        let arr = Array::from_vec(vec![1, 2, 3, 4]).reshape(&[2, 2]);
        let cond1 = arr.map(|x| x < 3);
        let choice1 = arr.map(|x| x * 10);
        let cond2 = arr.map(|x| x >= 3);
        let choice2 = arr.map(|x| x * 100);
        let result =
            select(&[cond1, cond2], &[choice1, choice2], 0).expect("operation should succeed");
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result.to_vec(), vec![10, 20, 300, 400]);
    }

    #[test]
    fn test_select_mismatched_lengths() {
        let arr = Array::from_vec(vec![1, 2, 3]);
        let cond1 = arr.map(|x| x < 2);
        let choice1 = arr.map(|x| x * 10);
        let cond2 = arr.map(|x| x >= 2);
        let choice2 = arr.map(|x| x * 100);
        let choice3 = arr.map(|x| x * 1000);
        let result = select(&[cond1, cond2], &[choice1, choice2, choice3], 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_select_mismatched_shapes() {
        let arr = Array::from_vec(vec![1, 2, 3, 4]);
        let cond1 = arr.map(|x| x < 3);
        let choice1 = arr.map(|x| x * 10);
        // Condition with the wrong length.
        let cond2 = Array::from_vec(vec![true, false]);
        let choice2 = arr.map(|x| x * 100);
        let result = select(&[cond1, cond2], &[choice1, choice2], 0);
        assert!(result.is_err());
    }

    #[test]
    fn test_combined_indexing_take_and_boolean() {
        let arr = Array::from_vec(vec![5, 2, 8, 1, 9, 3]);
        let indices = Array::from_vec(vec![0, 2, 4]);
        let reordered = take(&arr, &indices, None).expect("operation should succeed");
        let mask = reordered.map(|x| x > 7);
        let result = boolean_index(&reordered, &mask).expect("operation should succeed");
        assert_eq!(result.to_vec(), vec![8, 9]);
    }

    #[test]
    fn test_take_empty_indices() {
        let arr = Array::from_vec(vec![1, 2, 3, 4, 5]);
        let indices: Array<usize> = Array::from_vec(vec![]);
        let result = take(&arr, &indices, None).expect("operation should succeed");
        assert_eq!(result.size(), 0);
    }

    #[test]
    fn test_take_repeated_indices() {
        let arr = Array::from_vec(vec![10, 20, 30]);
        let indices = Array::from_vec(vec![0, 0, 1, 1, 2, 2]);
        let result = take(&arr, &indices, None).expect("operation should succeed");
        assert_eq!(result.to_vec(), vec![10, 10, 20, 20, 30, 30]);
    }
}