use super::advanced_ops::{ArrayView, IndexSpec, ResolvedIndex, Shape};
use crate::error::{NumRs2Error, Result};
use crate::traits::NumericElement;
/// Behavioural switches for [`FancyIndexEngine`].
#[derive(Debug, Clone)]
pub struct FancyIndexConfig {
    /// When `true`, integer indices are compared against their axis length
    /// before elements are read or written, and an out-of-range index is
    /// reported as `IndexOutOfBounds`.
    pub enable_bounds_checking: bool,
    /// Index-validation switch.
    /// NOTE(review): not read by any code in this file — confirm its consumer.
    pub enable_index_validation: bool,
    /// Temporary-memory budget (units not stated; presumably bytes).
    /// NOTE(review): not read by any code in this file — confirm its consumer.
    pub max_temp_memory: usize,
    /// Parallel-execution switch.
    /// NOTE(review): not read by any code in this file — confirm its consumer.
    pub enable_parallel: bool,
}
impl Default for FancyIndexConfig {
    /// Every switch enabled, with a `max_temp_memory` of 100_000_000.
    fn default() -> Self {
        const DEFAULT_MAX_TEMP_MEMORY: usize = 100_000_000;
        let all_enabled = true;
        Self {
            enable_bounds_checking: all_enabled,
            enable_index_validation: all_enabled,
            max_temp_memory: DEFAULT_MAX_TEMP_MEMORY,
            enable_parallel: all_enabled,
        }
    }
}
/// Engine for integer-array, boolean-mask, and slice/ellipsis/new-axis
/// indexing over [`ArrayView`]s, configured by a [`FancyIndexConfig`].
///
/// The type is over-aligned to 64 bytes through `repr(align(64))`.
#[repr(align(64))]
pub struct FancyIndexEngine {
    /// Switches consulted by the indexing methods (e.g. bounds checking).
    config: FancyIndexConfig,
}
impl Default for FancyIndexEngine {
    /// Builds an engine that uses [`FancyIndexConfig::default`].
    fn default() -> Self {
        let config = FancyIndexConfig::default();
        Self::new(config)
    }
}
impl FancyIndexEngine {
pub fn new(config: FancyIndexConfig) -> Self {
Self { config }
}
pub fn index_with_arrays<T>(
&self,
array: &ArrayView<T>,
indices: &[Vec<usize>],
) -> Result<Vec<T>>
where
T: NumericElement + Copy,
{
if indices.len() != array.shape().ndim() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Number of index arrays ({}) must match array dimensions ({})",
indices.len(),
array.shape().ndim()
)));
}
let output_length = indices[0].len();
for (i, index_array) in indices.iter().enumerate() {
if index_array.len() != output_length {
return Err(NumRs2Error::DimensionMismatch(format!(
"Index array {} has length {}, expected {}",
i,
index_array.len(),
output_length
)));
}
}
if self.config.enable_bounds_checking {
for (axis, index_array) in indices.iter().enumerate() {
let axis_size = array.shape().dims[axis];
for &idx in index_array {
if idx >= axis_size {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Index {} is out of bounds for axis {} of size {}",
idx, axis, axis_size
)));
}
}
}
}
let mut result = Vec::with_capacity(output_length);
for i in 0..output_length {
let multi_index: Vec<usize> = indices.iter().map(|arr| arr[i]).collect();
let element = array.get(&multi_index)?;
result.push(*element);
}
Ok(result)
}
pub fn index_with_boolean<T>(&self, array: &ArrayView<T>, mask: &[bool]) -> Result<Vec<T>>
where
T: NumericElement + Copy,
{
if mask.len() != array.shape().size() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Boolean mask length ({}) must match array size ({})",
mask.len(),
array.shape().size()
)));
}
let mut result = Vec::new();
for (i, &include) in mask.iter().enumerate() {
if include {
let multi_index = self.flat_to_multi_index(i, &array.shape().dims);
let element = array.get(&multi_index)?;
result.push(*element);
}
}
Ok(result)
}
pub fn advanced_index<T>(
&self,
array: &ArrayView<T>,
indices: &[IndexSpec],
) -> Result<FancyIndexResult<T>>
where
T: NumericElement + Copy,
{
let mut processed_indices = Vec::new();
let mut new_axes = Vec::new();
let mut output_shape_dims = Vec::new();
let mut axis = 0;
let mut ellipsis_used = false;
for (spec_idx, index_spec) in indices.iter().enumerate() {
match index_spec {
IndexSpec::Ellipsis => {
if ellipsis_used {
return Err(NumRs2Error::InvalidOperation(
"Only one ellipsis allowed".to_string(),
));
}
ellipsis_used = true;
let remaining_specs = indices.len() - spec_idx - 1;
let axes_to_add = array.shape().ndim().saturating_sub(remaining_specs);
for _ in 0..axes_to_add {
if axis < array.shape().ndim() {
processed_indices.push(ProcessedIndex::FullSlice(axis));
output_shape_dims.push(array.shape().dims[axis]);
axis += 1;
}
}
}
IndexSpec::NewAxis => {
new_axes.push(output_shape_dims.len());
output_shape_dims.push(1);
}
_ => {
if axis >= array.shape().ndim() {
return Err(NumRs2Error::DimensionMismatch(
"Too many indices for array".to_string(),
));
}
let resolved = index_spec.resolve(array.shape().dims[axis])?;
match resolved {
ResolvedIndex::Single(idx) => {
processed_indices.push(ProcessedIndex::Single(axis, idx));
}
ResolvedIndex::Multiple(idx_vec) => {
processed_indices.push(ProcessedIndex::Multiple(axis, idx_vec.clone()));
output_shape_dims.push(idx_vec.len());
}
}
axis += 1;
}
}
}
while axis < array.shape().ndim() {
processed_indices.push(ProcessedIndex::FullSlice(axis));
output_shape_dims.push(array.shape().dims[axis]);
axis += 1;
}
let data = self.extract_data(array, &processed_indices)?;
let output_shape = Shape::new(output_shape_dims);
Ok(FancyIndexResult {
data,
shape: output_shape,
new_axes,
})
}
pub fn set_with_arrays<T>(
&self,
array: &mut [T],
array_shape: &Shape,
indices: &[Vec<usize>],
values: &[T],
) -> Result<()>
where
T: NumericElement + Copy,
{
if indices.len() != array_shape.ndim() {
return Err(NumRs2Error::DimensionMismatch(
"Number of index arrays must match array dimensions".to_string(),
));
}
let output_length = indices[0].len();
if values.len() != output_length {
return Err(NumRs2Error::DimensionMismatch(
"Values array length must match index length".to_string(),
));
}
if self.config.enable_bounds_checking {
for (axis, index_array) in indices.iter().enumerate() {
let axis_size = array_shape.dims[axis];
for &idx in index_array {
if idx >= axis_size {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Index {} is out of bounds for axis {} of size {}",
idx, axis, axis_size
)));
}
}
}
}
let strides = array_shape.c_strides();
for i in 0..output_length {
let multi_index: Vec<usize> = indices.iter().map(|arr| arr[i]).collect();
let flat_index = self.multi_to_flat_index(&multi_index, &strides);
array[flat_index] = values[i];
}
Ok(())
}
pub fn set_with_boolean<T>(
&self,
array: &mut [T],
_array_shape: &Shape,
mask: &[bool],
value: T,
) -> Result<()>
where
T: NumericElement + Copy,
{
if mask.len() != array.len() {
return Err(NumRs2Error::DimensionMismatch(
"Boolean mask length must match array size".to_string(),
));
}
for (i, &include) in mask.iter().enumerate() {
if include {
array[i] = value;
}
}
Ok(())
}
pub fn where_condition<T, F>(&self, array: &ArrayView<T>, condition: F) -> Result<Vec<bool>>
where
T: NumericElement + Copy,
F: Fn(T) -> bool,
{
let mut mask = Vec::with_capacity(array.shape().size());
let mut indices = vec![0; array.shape().ndim()];
loop {
if let Ok(element) = array.get(&indices) {
mask.push(condition(*element));
}
let mut carry = 1;
for i in (0..indices.len()).rev() {
indices[i] += carry;
if indices[i] < array.shape().dims[i] {
carry = 0;
break;
} else {
indices[i] = 0;
carry = 1;
}
}
if carry == 1 {
break;
}
}
Ok(mask)
}
pub fn nonzero<T, F>(&self, array: &ArrayView<T>, condition: F) -> Result<Vec<Vec<usize>>>
where
T: NumericElement + Copy,
F: Fn(T) -> bool,
{
let mut result_indices = vec![Vec::new(); array.shape().ndim()];
let mut indices = vec![0; array.shape().ndim()];
loop {
if let Ok(element) = array.get(&indices) {
if condition(*element) {
for (axis, &idx) in indices.iter().enumerate() {
result_indices[axis].push(idx);
}
}
}
let mut carry = 1;
for i in (0..indices.len()).rev() {
indices[i] += carry;
if indices[i] < array.shape().dims[i] {
carry = 0;
break;
} else {
indices[i] = 0;
carry = 1;
}
}
if carry == 1 {
break;
}
}
Ok(result_indices)
}
pub fn take<T>(&self, array: &ArrayView<T>, indices: &[usize], axis: usize) -> Result<Vec<T>>
where
T: NumericElement + Copy,
{
if axis >= array.shape().ndim() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Axis {} is out of bounds for array of dimension {}",
axis,
array.shape().ndim()
)));
}
let axis_size = array.shape().dims[axis];
if self.config.enable_bounds_checking {
for &idx in indices {
if idx >= axis_size {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Index {} is out of bounds for axis of size {}",
idx, axis_size
)));
}
}
}
let mut result = Vec::new();
let mut base_indices = vec![0; array.shape().ndim()];
let indices_vec = indices.to_vec();
self.iterate_other_axes(
array.shape(),
axis,
&mut base_indices,
0,
&mut |current_indices| {
for &take_idx in &indices_vec {
let mut temp_indices = current_indices.to_vec();
temp_indices[axis] = take_idx;
if let Ok(element) = array.get(&temp_indices) {
result.push(*element);
}
}
},
);
Ok(result)
}
pub fn choose<T>(&self, choices: &[&ArrayView<T>], index_array: &[usize]) -> Result<Vec<T>>
where
T: NumericElement + Copy,
{
if choices.is_empty() {
return Err(NumRs2Error::InvalidOperation(
"No choice arrays provided".to_string(),
));
}
let reference_shape = choices[0].shape();
for (i, choice) in choices.iter().enumerate() {
if choice.shape() != reference_shape {
return Err(NumRs2Error::DimensionMismatch(format!(
"Choice array {} has different shape than reference",
i
)));
}
}
if index_array.len() != reference_shape.size() {
return Err(NumRs2Error::DimensionMismatch(
"Index array length must match choice array size".to_string(),
));
}
let mut result = Vec::with_capacity(index_array.len());
for (flat_idx, &choice_idx) in index_array.iter().enumerate() {
if choice_idx >= choices.len() {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Choice index {} is out of bounds for {} choices",
choice_idx,
choices.len()
)));
}
let multi_index = self.flat_to_multi_index(flat_idx, &reference_shape.dims);
let element = choices[choice_idx].get(&multi_index)?;
result.push(*element);
}
Ok(result)
}
fn flat_to_multi_index(&self, flat_index: usize, shape: &[usize]) -> Vec<usize> {
let mut indices = Vec::with_capacity(shape.len());
let mut remaining = flat_index;
for &dim_size in shape.iter().rev() {
indices.push(remaining % dim_size);
remaining /= dim_size;
}
indices.reverse();
indices
}
fn multi_to_flat_index(&self, multi_index: &[usize], strides: &[usize]) -> usize {
multi_index
.iter()
.zip(strides.iter())
.map(|(&idx, &stride)| idx * stride)
.sum()
}
fn extract_data<T>(
&self,
array: &ArrayView<T>,
processed_indices: &[ProcessedIndex],
) -> Result<Vec<T>>
where
T: NumericElement + Copy,
{
let mut result = Vec::new();
let mut current_indices = vec![0; array.shape().ndim()];
self.extract_recursive(
array,
processed_indices,
0,
&mut current_indices,
&mut result,
)?;
Ok(result)
}
#[allow(clippy::only_used_in_recursion)]
fn extract_recursive<T>(
&self,
array: &ArrayView<T>,
processed_indices: &[ProcessedIndex],
depth: usize,
current_indices: &mut [usize],
result: &mut Vec<T>,
) -> Result<()>
where
T: NumericElement + Copy,
{
if depth >= processed_indices.len() {
let element = array.get(current_indices)?;
result.push(*element);
return Ok(());
}
match &processed_indices[depth] {
ProcessedIndex::Single(axis, idx) => {
current_indices[*axis] = *idx;
self.extract_recursive(
array,
processed_indices,
depth + 1,
current_indices,
result,
)?;
}
ProcessedIndex::Multiple(axis, indices) => {
for &idx in indices {
current_indices[*axis] = idx;
self.extract_recursive(
array,
processed_indices,
depth + 1,
current_indices,
result,
)?;
}
}
ProcessedIndex::FullSlice(axis) => {
for idx in 0..array.shape().dims[*axis] {
current_indices[*axis] = idx;
self.extract_recursive(
array,
processed_indices,
depth + 1,
current_indices,
result,
)?;
}
}
}
Ok(())
}
#[allow(clippy::only_used_in_recursion)]
fn iterate_other_axes<F>(
&self,
shape: &Shape,
skip_axis: usize,
indices: &mut [usize],
current_axis: usize,
callback: &mut F,
) where
F: FnMut(&[usize]),
{
if current_axis >= shape.ndim() {
callback(indices);
return;
}
if current_axis == skip_axis {
self.iterate_other_axes(shape, skip_axis, indices, current_axis + 1, callback);
} else {
for i in 0..shape.dims[current_axis] {
indices[current_axis] = i;
self.iterate_other_axes(shape, skip_axis, indices, current_axis + 1, callback);
}
}
}
}
/// A per-axis selection produced by `FancyIndexEngine::advanced_index` and
/// walked when the selected data is extracted.
#[derive(Debug, Clone)]
enum ProcessedIndex {
    /// `(axis, position)`: pins `axis` to a single position. The axis does
    /// not appear in the output shape.
    Single(usize, usize),
    /// `(axis, positions)`: visits the listed positions of `axis` in order.
    Multiple(usize, Vec<usize>),
    /// `axis`: visits every position of `axis`.
    FullSlice(usize),
}
/// Output of `FancyIndexEngine::advanced_index`: the selected elements and
/// the shape they form.
#[derive(Debug, Clone)]
pub struct FancyIndexResult<T> {
    /// Selected elements in row-major order of `shape`.
    pub data: Vec<T>,
    /// Output shape, including length-1 axes inserted by `NewAxis`.
    pub shape: Shape,
    /// Positions within `shape.dims` of the axes inserted by `NewAxis`.
    pub new_axes: Vec<usize>,
}
impl<T> FancyIndexResult<T> {
    /// Borrows `data` as an `ArrayView` with a copy of `shape`.
    ///
    /// # Errors
    ///
    /// Whatever `ArrayView::from_data` reports (NOTE(review): presumably a
    /// data-length/shape mismatch — confirm).
    pub fn to_view(&self) -> Result<ArrayView<'_, T>> {
        let view_shape = self.shape.clone();
        ArrayView::from_data(&self.data, view_shape)
    }
}
/// Stateless indexing helpers that need no engine configuration.
pub struct SpecializedIndexing;
impl SpecializedIndexing {
    /// Gathers every combination (Cartesian product) of per-axis coordinates.
    ///
    /// `coordinates[k]` lists the positions to visit along axis `k`. The result
    /// holds the product of the list lengths, enumerated with the last axis
    /// varying fastest. For example, `[[0, 2], [1, 0]]` on a 2-D array yields
    /// `a[0,1], a[0,0], a[2,1], a[2,0]`.
    ///
    /// # Errors
    ///
    /// `DimensionMismatch` if `coordinates.len()` differs from the array rank;
    /// any error returned by `ArrayView::get`.
    pub fn index_with_coordinates<T>(
        array: &ArrayView<T>,
        coordinates: &[Vec<usize>],
    ) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        if coordinates.len() != array.shape().ndim() {
            return Err(NumRs2Error::DimensionMismatch(
                "Number of coordinate arrays must match array dimensions".to_string(),
            ));
        }
        let total_combinations: usize = coordinates.iter().map(Vec::len).product();
        let mut result = Vec::with_capacity(total_combinations);
        // One coordinate buffer is reused for every combination.
        let mut multi_index = vec![0; coordinates.len()];
        for combination_idx in 0..total_combinations {
            // Mixed-radix decode of `combination_idx`, last axis fastest.
            let mut remaining = combination_idx;
            for (slot, coord_array) in multi_index.iter_mut().zip(coordinates).rev() {
                *slot = coord_array[remaining % coord_array.len()];
                remaining /= coord_array.len();
            }
            result.push(*array.get(&multi_index)?);
        }
        Ok(result)
    }

    /// Alias for [`SpecializedIndexing::index_with_coordinates`].
    pub fn meshgrid_index<T>(array: &ArrayView<T>, grid_indices: &[Vec<usize>]) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        Self::index_with_coordinates(array, grid_indices)
    }

    /// Combines several flat boolean masks elementwise with `combine_op` and
    /// returns the selected elements in row-major order.
    ///
    /// `And` keeps elements where every condition holds. `Or` keeps elements
    /// where any condition holds. `Xor` keeps elements where an odd number of
    /// conditions hold.
    ///
    /// # Errors
    ///
    /// `InvalidOperation` if `conditions` is empty; `DimensionMismatch` if any
    /// condition's length differs from the array size.
    pub fn multi_boolean_index<T>(
        array: &ArrayView<T>,
        conditions: &[Vec<bool>],
        combine_op: BooleanCombineOp,
    ) -> Result<Vec<T>>
    where
        T: NumericElement + Copy,
    {
        let (first, rest) = conditions.split_first().ok_or_else(|| {
            NumRs2Error::InvalidOperation("No conditions provided".to_string())
        })?;
        let array_size = array.shape().size();
        for (i, condition) in conditions.iter().enumerate() {
            if condition.len() != array_size {
                return Err(NumRs2Error::DimensionMismatch(format!(
                    "Condition {} length doesn't match array size",
                    i
                )));
            }
        }
        // Start from the first condition. The old seed of all-`true` is only
        // the identity for `And`: it made `Or` select every element and
        // `Xor` invert the result.
        let mut combined_mask = first.clone();
        for condition in rest {
            for (mask_val, &cond_val) in combined_mask.iter_mut().zip(condition) {
                *mask_val = match combine_op {
                    BooleanCombineOp::And => *mask_val && cond_val,
                    BooleanCombineOp::Or => *mask_val || cond_val,
                    BooleanCombineOp::Xor => *mask_val ^ cond_val,
                };
            }
        }
        FancyIndexEngine::default().index_with_boolean(array, &combined_mask)
    }
}
/// How `SpecializedIndexing::multi_boolean_index` merges its masks elementwise.
#[derive(Debug, Clone, Copy)]
pub enum BooleanCombineOp {
    /// Logical AND (`&&`) of the mask values.
    And,
    /// Logical OR (`||`) of the mask values.
    Or,
    /// Exclusive OR (`^`) of the mask values.
    Xor,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::arrays::advanced_ops::{ArrayView, Shape};

    #[test]
    fn test_fancy_index_engine_creation() {
        let config = FancyIndexConfig::default();
        let _engine = FancyIndexEngine::new(config);
        let data = vec![1, 2, 3, 4, 5, 6];
        let shape = Shape::from_2d(2, 3);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        assert_eq!(view.shape().size(), 6);
    }

    #[test]
    fn test_fancy_indexing_with_arrays() {
        let engine = FancyIndexEngine::default();
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let shape = Shape::new(vec![3, 3]);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        let row_indices = vec![0, 2, 1];
        let col_indices = vec![1, 0, 2];
        let indices = [row_indices, col_indices];
        let result = engine
            .index_with_arrays(&view, &indices)
            .expect("test: operation should succeed");
        // Picks (0, 1), (2, 0) and (1, 2) from the row-major 3x3 grid.
        assert_eq!(result, vec![2, 7, 6]);

        // Row 3 does not exist in a 3-row array; bounds checking is on by default.
        let out_of_bounds: [Vec<usize>; 2] = [vec![3], vec![0]];
        assert!(engine.index_with_arrays(&view, &out_of_bounds).is_err());
    }

    #[test]
    fn test_boolean_indexing() {
        let engine = FancyIndexEngine::default();
        let data = vec![1, 2, 3, 4, 5, 6];
        let shape = Shape::from_1d(6);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        let mask = vec![true, false, true, false, true, false];
        let result = engine
            .index_with_boolean(&view, &mask)
            .expect("test: operation should succeed");
        assert_eq!(result, vec![1, 3, 5]);
    }

    #[test]
    fn test_where_condition() {
        let engine = FancyIndexEngine::default();
        let data = vec![1, 2, 3, 4, 5, 6];
        let shape = Shape::from_1d(6);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        let mask = engine
            .where_condition(&view, |x| x > 3)
            .expect("test: operation should succeed");
        assert_eq!(mask, vec![false, false, false, true, true, true]);
    }

    #[test]
    fn test_nonzero() {
        let engine = FancyIndexEngine::default();
        let data = vec![0, 1, 0, 2, 0, 3];
        let shape = Shape::from_2d(2, 3);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        let indices = engine
            .nonzero(&view, |x| x != 0)
            .expect("test: operation should succeed");
        // Row coordinates of the non-zero elements.
        assert_eq!(indices[0], vec![0, 1, 1]);
        // Column coordinates of the non-zero elements.
        assert_eq!(indices[1], vec![1, 0, 2]);
    }

    #[test]
    fn test_take_along_axis() {
        let engine = FancyIndexEngine::default();
        let data = vec![1, 2, 3, 4, 5, 6];
        let shape = Shape::from_2d(2, 3);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        // Reorder the columns of each row as 2, 0, 1.
        let indices = vec![2, 0, 1];
        let result = engine
            .take(&view, &indices, 1)
            .expect("test: operation should succeed");
        assert_eq!(result, vec![3, 1, 2, 6, 4, 5]);
    }

    #[test]
    fn test_choose() {
        let engine = FancyIndexEngine::default();
        let data1 = vec![1, 2, 3];
        let data2 = vec![10, 20, 30];
        let data3 = vec![100, 200, 300];
        let shape = Shape::from_1d(3);
        let view1 =
            ArrayView::from_data(&data1, shape.clone()).expect("test: operation should succeed");
        let view2 =
            ArrayView::from_data(&data2, shape.clone()).expect("test: operation should succeed");
        let view3 = ArrayView::from_data(&data3, shape).expect("test: operation should succeed");
        let choices = vec![&view1, &view2, &view3];
        let index_array = vec![0, 2, 1];
        let result = engine
            .choose(&choices, &index_array)
            .expect("test: operation should succeed");
        assert_eq!(result, vec![1, 200, 30]);
    }

    #[test]
    fn test_advanced_indexing() {
        let engine = FancyIndexEngine::default();
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let shape = Shape::new(vec![3, 4]);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        let indices = vec![
            // Row 1.
            IndexSpec::Int(1),
            // Columns 1..3.
            IndexSpec::Slice(Some(1), Some(3), Some(1)),
        ];
        let result = engine
            .advanced_index(&view, &indices)
            .expect("test: operation should succeed");
        // Row 1 is [5, 6, 7, 8]; columns 1..3 give [6, 7].
        assert_eq!(result.data, vec![6, 7]);
        // The integer index drops the row axis, leaving one axis of length 2.
        assert_eq!(result.shape.dims, vec![2]);
    }

    #[test]
    fn test_coordinate_indexing() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        let shape = Shape::new(vec![3, 3]);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        let coordinates = vec![vec![0, 2], vec![1, 0]];
        let result = SpecializedIndexing::index_with_coordinates(&view, &coordinates)
            .expect("test: operation should succeed");
        // Positions (0,1), (0,0), (2,1), (2,0).
        assert_eq!(result, vec![2, 1, 8, 7]);
    }

    #[test]
    fn test_multi_boolean_indexing() {
        let data = vec![1, 2, 3, 4, 5, 6];
        let shape = Shape::from_1d(6);
        let view = ArrayView::from_data(&data, shape).expect("test: operation should succeed");
        let condition1 = vec![true, true, false, true, false, true];
        let condition2 = vec![false, true, true, true, true, false];
        let conditions = vec![condition1, condition2];
        let result =
            SpecializedIndexing::multi_boolean_index(&view, &conditions, BooleanCombineOp::And)
                .expect("test: operation should succeed");
        // Only positions 1 and 3 are true in both conditions.
        assert_eq!(result, vec![2, 4]);
    }
}