use crate::array::Array;
use crate::error::{NumRs2Error, Result};
/// Builds one "open grid" index array per axis of `shape`.
///
/// The array for axis `i` holds `0, 1, ..., shape[i] - 1` converted to `T`, reshaped so
/// that axis `i` has length `shape[i]` and every other axis has length 1.
/// An empty `shape` yields an empty vector.
///
/// # Errors
/// Returns [`NumRs2Error::InvalidOperation`] if an index cannot be represented in `T`.
pub fn indices_grid<T: Clone + num_traits::Zero + num_traits::One + num_traits::NumCast>(
    shape: &[usize],
) -> Result<Vec<Array<T>>> {
    let rank = shape.len();
    let mut grids = Vec::with_capacity(rank);
    for (axis, &len) in shape.iter().enumerate() {
        // Convert every position along this axis; the first failed cast aborts.
        let values = (0..len)
            .map(|j| {
                T::from(j).ok_or_else(|| {
                    NumRs2Error::InvalidOperation("usize should be castable to T".to_string())
                })
            })
            .collect::<Result<Vec<T>>>()?;
        let mut grid_shape = vec![1; rank];
        grid_shape[axis] = len;
        grids.push(Array::from_vec(values).reshape(&grid_shape));
    }
    Ok(grids)
}
/// Collects the coordinates of every element of `shape` for which `mask_fn` returns `true`.
///
/// Coordinates are visited in row-major order (last axis varies fastest). The result
/// holds one 1-D array per axis. Entry `d` lists the axis-`d` coordinate of each selected
/// element, in visit order.
///
/// An empty `shape` yields an empty vector. A shape containing a zero-length axis yields
/// one empty array per axis.
///
/// # Errors
/// Currently never returns an error.
pub fn mask_indices<F>(shape: &[usize], mask_fn: F) -> Result<Vec<Array<usize>>>
where
    F: Fn(&[usize]) -> bool,
{
    let ndim = shape.len();
    if ndim == 0 {
        return Ok(Vec::new());
    }
    let element_count: usize = shape.iter().product();
    let mut selected: Vec<Vec<usize>> = vec![Vec::new(); ndim];
    let mut cursor = vec![0usize; ndim];
    for _ in 0..element_count {
        if mask_fn(&cursor) {
            for (column, &coord) in selected.iter_mut().zip(cursor.iter()) {
                column.push(coord);
            }
        }
        // Advance like an odometer: bump the last axis and carry leftwards on wrap-around.
        for axis in (0..ndim).rev() {
            cursor[axis] += 1;
            if cursor[axis] < shape[axis] {
                break;
            }
            cursor[axis] = 0;
        }
    }
    Ok(selected.into_iter().map(Array::from_vec).collect())
}
/// Converts per-axis index arrays into flat, row-major (C-order) indices for an array of
/// shape `dims`.
///
/// `multi_index[d]` holds the coordinates along axis `d`. All index arrays must share one
/// shape, and the result has that shape too (1-D inputs yield a 1-D result).
///
/// `mode` selects how out-of-range coordinates are handled:
/// - `"raise"`: report an error;
/// - `"wrap"`: reduce the coordinate modulo the axis length;
/// - `"clip"`: clamp the coordinate to the last valid position on the axis.
///
/// # Errors
/// - [`NumRs2Error::DimensionMismatch`] if `multi_index.len() != dims.len()`.
/// - [`NumRs2Error::InvalidOperation`] if `mode` is not one of the values above, or if
///   the product of `dims` does not fit in `usize`.
/// - [`NumRs2Error::ShapeMismatch`] if the index arrays do not all share one shape.
/// - [`NumRs2Error::IndexOutOfBounds`] for an out-of-range coordinate in `"raise"` mode,
///   or for any coordinate on a zero-length axis in `"wrap"`/`"clip"` mode.
pub fn ravel_multi_index(
    multi_index: &[Array<usize>],
    dims: &[usize],
    mode: &str,
) -> Result<Array<usize>> {
    /// Parsed form of the `mode` argument.
    #[derive(Clone, Copy)]
    enum Mode {
        Raise,
        Wrap,
        Clip,
    }

    if multi_index.len() != dims.len() {
        return Err(NumRs2Error::DimensionMismatch(format!(
            "Number of index arrays ({}) does not match number of dimensions ({})",
            multi_index.len(),
            dims.len()
        )));
    }

    // Parse the mode up front so an invalid mode is rejected even when there is
    // nothing to convert, and so the hot loop does not re-compare strings.
    let mode = match mode {
        "raise" => Mode::Raise,
        "wrap" => Mode::Wrap,
        "clip" => Mode::Clip,
        _ => {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Invalid mode '{}'. Must be 'raise', 'wrap', or 'clip'",
                mode
            )));
        }
    };

    if multi_index.is_empty() {
        return Ok(Array::from_vec(vec![]));
    }

    let shape = multi_index[0].shape();
    for arr in multi_index.iter().skip(1) {
        if arr.shape() != shape {
            return Err(NumRs2Error::ShapeMismatch {
                expected: shape.clone(),
                actual: arr.shape(),
            });
        }
    }

    // Row-major strides, built from the last axis outwards. The running total is
    // checked so that every flat index (bounded by the total size) fits in `usize`.
    let mut strides = vec![1usize; dims.len()];
    let mut total: usize = 1;
    for (stride, &dim) in strides.iter_mut().zip(dims).rev() {
        *stride = total;
        total = total.checked_mul(dim).ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "invalid dims: array size defined by dims is larger than the maximum possible size"
                    .to_string(),
            )
        })?;
    }

    // Materialise each index array once rather than once per element.
    let columns: Vec<Vec<usize>> = multi_index.iter().map(|arr| arr.to_vec()).collect();
    let size = multi_index[0].size();
    let mut flat_indices = Vec::with_capacity(size);
    for i in 0..size {
        let mut flat_idx = 0;
        for (dim_idx, ((column, &dim_size), &stride)) in
            columns.iter().zip(dims).zip(&strides).enumerate()
        {
            let idx = column[i];
            let bounded_idx = match mode {
                Mode::Raise if idx < dim_size => idx,
                Mode::Raise => {
                    return Err(NumRs2Error::IndexOutOfBounds(format!(
                        "Index {} is out of bounds for dimension {} with size {}",
                        idx, dim_idx, dim_size
                    )));
                }
                // Neither wrapping nor clipping can land inside an empty axis;
                // `idx % 0` and `0 - 1` would otherwise panic.
                Mode::Wrap | Mode::Clip if dim_size == 0 => {
                    return Err(NumRs2Error::IndexOutOfBounds(format!(
                        "Index {} cannot be mapped into dimension {} with size 0",
                        idx, dim_idx
                    )));
                }
                Mode::Wrap => idx % dim_size,
                Mode::Clip => idx.min(dim_size - 1),
            };
            flat_idx += bounded_idx * stride;
        }
        flat_indices.push(flat_idx);
    }

    let flat = Array::from_vec(flat_indices);
    if multi_index[0].ndim() == 1 {
        Ok(flat)
    } else {
        Ok(flat.reshape(&shape))
    }
}
/// Converts flat, row-major (C-order) indices into per-axis coordinates for an array of
/// shape `dims`.
///
/// Returns one array per axis of `dims`. Array `d` holds the axis-`d` coordinate of each
/// entry of `indices` and has the same shape as `indices` (1-D inputs yield 1-D outputs).
///
/// # Errors
/// - [`NumRs2Error::InvalidOperation`] if `dims` is empty, or if the product of `dims`
///   does not fit in `usize`.
/// - [`NumRs2Error::IndexOutOfBounds`] if any flat index is `>=` the product of `dims`.
pub fn unravel_index(indices: &Array<usize>, dims: &[usize]) -> Result<Vec<Array<usize>>> {
    if dims.is_empty() {
        return Err(NumRs2Error::InvalidOperation(
            "Cannot unravel indices for empty dimensions".to_string(),
        ));
    }

    // Row-major strides, built from the last axis outwards. Checked multiplication
    // turns an oversized `dims` into an error instead of an overflow panic or a
    // silently wrapped bounds check.
    let mut strides = vec![1usize; dims.len()];
    let mut total_size: usize = 1;
    for (stride, &dim) in strides.iter_mut().zip(dims).rev() {
        *stride = total_size;
        total_size = total_size.checked_mul(dim).ok_or_else(|| {
            NumRs2Error::InvalidOperation(
                "invalid dims: array size defined by dims is larger than the maximum possible size"
                    .to_string(),
            )
        })?;
    }

    let flat_indices = indices.to_vec();
    if let Some(&idx) = flat_indices.iter().find(|&&idx| idx >= total_size) {
        return Err(NumRs2Error::IndexOutOfBounds(format!(
            "Index {} is out of bounds for array of size {}",
            idx, total_size
        )));
    }

    // Any index that reaches this point is < total_size. So when total_size is 0 there
    // are no indices, and no zero stride is ever used as a divisor below.
    let mut multi_indices = vec![vec![0; flat_indices.len()]; dims.len()];
    for (i, &flat_idx) in flat_indices.iter().enumerate() {
        let mut remainder = flat_idx;
        for (axis_coords, &stride) in multi_indices.iter_mut().zip(&strides) {
            axis_coords[i] = remainder / stride;
            remainder %= stride;
        }
    }

    let shape = indices.shape();
    let result: Vec<Array<usize>> = multi_indices
        .into_iter()
        .map(|coords| {
            let coords = Array::from_vec(coords);
            if shape.len() == 1 {
                coords
            } else {
                coords.reshape(&shape)
            }
        })
        .collect();
    Ok(result)
}
/// Returns the `(rows, cols)` indices of the lower triangle of an `n x m` matrix.
///
/// Position `(i, j)` is included when `j <= i + k`:
/// - `k = 0` selects the main diagonal and everything below it;
/// - `k > 0` also includes that many diagonals above it;
/// - `k < 0` drops diagonals starting from the main one.
///
/// `m` defaults to `n`. Indices are emitted in row-major order.
///
/// # Errors
/// Currently never returns an error.
pub fn tril_indices(n: usize, k: isize, m: Option<usize>) -> Result<(Array<usize>, Array<usize>)> {
    let m = m.unwrap_or(n);
    let mut row_indices = Vec::new();
    let mut col_indices = Vec::new();
    for i in 0..n {
        // Last admissible column is `i + k`. Saturate so a very large `k`
        // (e.g. `isize::MAX`) cannot overflow.
        let last = (i as isize).saturating_add(k);
        if last < 0 {
            continue;
        }
        // `last <= isize::MAX`, so `last as usize + 1` cannot overflow `usize`.
        let end = m.min(last as usize + 1);
        row_indices.extend(std::iter::repeat(i).take(end));
        col_indices.extend(0..end);
    }
    Ok((Array::from_vec(row_indices), Array::from_vec(col_indices)))
}
/// Returns the `(rows, cols)` indices of the upper triangle of an `n x m` matrix.
///
/// Position `(i, j)` is included when `j >= i + k`:
/// - `k = 0` selects the main diagonal and everything above it;
/// - `k > 0` drops diagonals starting from the main one;
/// - `k < 0` also includes that many diagonals below it.
///
/// `m` defaults to `n`. Indices are emitted in row-major order.
///
/// # Errors
/// Currently never returns an error.
pub fn triu_indices(n: usize, k: isize, m: Option<usize>) -> Result<(Array<usize>, Array<usize>)> {
    let m = m.unwrap_or(n);
    let mut row_indices = Vec::new();
    let mut col_indices = Vec::new();
    for i in 0..n {
        // First admissible column is `i + k`, clamped to `0..=m`. Saturate so a very
        // large `k` cannot overflow.
        let first = (i as isize).saturating_add(k);
        let start = if first <= 0 {
            0
        } else {
            (first as usize).min(m)
        };
        row_indices.extend(std::iter::repeat(i).take(m - start));
        col_indices.extend(start..m);
    }
    Ok((Array::from_vec(row_indices), Array::from_vec(col_indices)))
}
/// Returns the index arrays addressing the main diagonal of an array with `ndim` axes
/// (default 2), each of length `n`.
///
/// One array is produced per axis, and every array is `[0, 1, ..., n - 1]`.
///
/// # Errors
/// Returns [`NumRs2Error::InvalidOperation`] when `ndim` is `Some(0)`.
pub fn diag_indices(n: usize, ndim: Option<usize>) -> Result<Vec<Array<usize>>> {
    let axes = match ndim {
        Some(0) => {
            return Err(NumRs2Error::InvalidOperation(
                "Number of dimensions must be at least 1".to_string(),
            ));
        }
        Some(count) => count,
        None => 2,
    };
    let diagonal: Vec<usize> = (0..n).collect();
    Ok((0..axes)
        .map(|_| Array::from_vec(diagonal.clone()))
        .collect())
}
/// Returns the main-diagonal index arrays for `arr`, one per axis.
///
/// The diagonal length is the length of the shortest axis of `arr`, so non-square
/// arrays are accepted.
///
/// # Errors
/// Returns [`NumRs2Error::InvalidOperation`] if `arr` is zero-dimensional.
pub fn diag_indices_from<T: Clone>(arr: &Array<T>) -> Result<Vec<Array<usize>>> {
    let ndim = arr.ndim();
    if ndim == 0 {
        return Err(NumRs2Error::InvalidOperation(
            "Array must have at least 1 dimension".to_string(),
        ));
    }
    let shortest_axis = arr.shape().iter().copied().min().unwrap_or(0);
    diag_indices(shortest_axis, Some(ndim))
}
/// Returns the lower-triangle `(rows, cols)` indices for the trailing two axes of `arr`.
///
/// The matrix is taken to be `shape[ndim - 2] x shape[ndim - 1]`. `k` is the diagonal
/// offset passed to [`tril_indices`] and defaults to `0`.
///
/// # Errors
/// Returns [`NumRs2Error::InvalidOperation`] if `arr` has fewer than two dimensions.
pub fn tril_indices_from<T: Clone>(
    arr: &Array<T>,
    k: Option<isize>,
) -> Result<(Array<usize>, Array<usize>)> {
    let shape = arr.shape();
    match shape[..] {
        [.., rows, cols] => tril_indices(rows, k.unwrap_or(0), Some(cols)),
        _ => Err(NumRs2Error::InvalidOperation(
            "Array must be at least 2-dimensional".to_string(),
        )),
    }
}
/// Returns the upper-triangle `(rows, cols)` indices for the trailing two axes of `arr`.
///
/// The matrix is taken to be `shape[ndim - 2] x shape[ndim - 1]`. `k` is the diagonal
/// offset passed to [`triu_indices`] and defaults to `0`.
///
/// # Errors
/// Returns [`NumRs2Error::InvalidOperation`] if `arr` has fewer than two dimensions.
pub fn triu_indices_from<T: Clone>(
    arr: &Array<T>,
    k: Option<isize>,
) -> Result<(Array<usize>, Array<usize>)> {
    let shape = arr.shape();
    match shape[..] {
        [.., rows, cols] => triu_indices(rows, k.unwrap_or(0), Some(cols)),
        _ => Err(NumRs2Error::InvalidOperation(
            "Array must be at least 2-dimensional".to_string(),
        )),
    }
}