use scirs2_core::ndarray::{Array, Array1, ArrayD, Axis, Dimension, IxDyn};
use std::collections::HashSet;
/// Errors reported by the gather, scatter, top-k and masking routines.
///
/// The type is comparable (`PartialEq`/`Eq`) so callers and tests can assert
/// on an exact error value instead of pattern-matching only.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GatherScatterError {
    /// An index does not address a valid position along `axis`.
    OutOfBoundsIndex {
        /// The offending index.
        index: i64,
        /// Length of the axis the index was checked against.
        axis_len: usize,
        /// The axis being indexed.
        axis: usize,
    },
    /// Two shapes that were required to agree do not.
    ShapeMismatch {
        /// The shape that was required.
        expected: Vec<usize>,
        /// The shape that was actually supplied (may be empty when unknown).
        got: Vec<usize>,
    },
    /// The requested axis does not exist in the tensor.
    AxisOutOfRange {
        /// The requested axis.
        axis: usize,
        /// Number of dimensions of the tensor.
        ndim: usize,
    },
    /// The operation needs at least one element but received none.
    EmptyInput,
    /// The index tensor has a different rank than the input tensor.
    IndexRankMismatch {
        /// Rank of the input tensor.
        input_ndim: usize,
        /// Rank of the index tensor.
        index_ndim: usize,
    },
    /// More elements were requested than the axis contains.
    KTooLarge {
        /// The requested element count.
        k: usize,
        /// Length of the axis being selected from.
        axis_len: usize,
    },
}

impl std::fmt::Display for GatherScatterError {
    /// Writes a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            GatherScatterError::OutOfBoundsIndex {
                index,
                axis_len,
                axis,
            } => write!(
                f,
                "index {index} is out of bounds for axis {axis} with length {axis_len}"
            ),
            GatherScatterError::ShapeMismatch { expected, got } => {
                write!(f, "shape mismatch: expected {expected:?}, got {got:?}")
            }
            GatherScatterError::AxisOutOfRange { axis, ndim } => write!(
                f,
                "axis {axis} is out of range for tensor with {ndim} dimensions"
            ),
            GatherScatterError::EmptyInput => write!(f, "empty input"),
            GatherScatterError::IndexRankMismatch {
                input_ndim,
                index_ndim,
            } => write!(
                f,
                "index rank mismatch: input has {input_ndim} dims, indices have {index_ndim} dims"
            ),
            GatherScatterError::KTooLarge { k, axis_len } => {
                write!(f, "k={k} exceeds axis length {axis_len}")
            }
        }
    }
}

impl std::error::Error for GatherScatterError {}
/// Resolves a possibly negative index against an axis of length `axis_len`.
///
/// Negative values count back from the end of the axis (`-1` is the last
/// position). Returns [`GatherScatterError::OutOfBoundsIndex`] carrying the
/// original `raw` value when the resolved position is not in `0..axis_len`.
#[inline]
fn normalize_index(raw: i64, axis_len: usize, axis: usize) -> Result<usize, GatherScatterError> {
    let len = axis_len as i64;
    let resolved = if raw < 0 { raw + len } else { raw };
    if (0..len).contains(&resolved) {
        Ok(resolved as usize)
    } else {
        Err(GatherScatterError::OutOfBoundsIndex {
            index: raw,
            axis_len,
            axis,
        })
    }
}
/// Verifies that `axis` names one of the `ndim` dimensions.
///
/// Returns [`GatherScatterError::AxisOutOfRange`] when `axis >= ndim`.
#[inline]
fn check_axis(axis: usize, ndim: usize) -> Result<(), GatherScatterError> {
    if axis < ndim {
        Ok(())
    } else {
        Err(GatherScatterError::AxisOutOfRange { axis, ndim })
    }
}
/// Picks whole slices of `input` along `axis` in the order given by `indices`.
///
/// The result has the shape of `input` with the length of `axis` replaced by
/// `indices.len()`; position `j` along `axis` holds the slice of `input` at
/// `indices[j]`. Negative indices count back from the end of the axis.
///
/// # Errors
///
/// * [`GatherScatterError::AxisOutOfRange`] if `axis` is not a dimension of `input`.
/// * [`GatherScatterError::EmptyInput`] if `indices` is empty.
/// * [`GatherScatterError::OutOfBoundsIndex`] for the first index that does not
///   resolve to a valid position along `axis`.
/// * [`GatherScatterError::ShapeMismatch`] if stacking the selected slices fails.
pub fn gather(
    input: &ArrayD<f64>,
    indices: &[i64],
    axis: usize,
) -> Result<ArrayD<f64>, GatherScatterError> {
    check_axis(axis, input.ndim())?;
    if indices.is_empty() {
        return Err(GatherScatterError::EmptyInput);
    }
    let axis_len = input.shape()[axis];

    // Borrow each selected slice instead of deep-copying it first: `stack`
    // already copies the data once into the freshly allocated output.
    let views = indices
        .iter()
        .map(|&raw| {
            normalize_index(raw, axis_len, axis).map(|idx| input.index_axis(Axis(axis), idx))
        })
        .collect::<Result<Vec<_>, _>>()?;

    scirs2_core::ndarray::stack(Axis(axis), &views).map_err(|_| {
        GatherScatterError::ShapeMismatch {
            expected: input.shape().to_vec(),
            got: vec![],
        }
    })
}
/// Element-wise gather along `axis` driven by an index tensor.
///
/// `indices` must have the same rank as `input` and the same extent on every
/// dimension except `axis`. The output has the shape of `indices`; each output
/// element is read from `input` at the same coordinates, with the coordinate
/// on `axis` replaced by the (possibly negative) index value stored there.
///
/// # Errors
///
/// * [`GatherScatterError::AxisOutOfRange`] if `axis` is not a dimension of `input`.
/// * [`GatherScatterError::IndexRankMismatch`] if the ranks differ.
/// * [`GatherScatterError::ShapeMismatch`] if an off-axis extent differs.
/// * [`GatherScatterError::OutOfBoundsIndex`] if an index value is invalid.
pub fn gather_nd(
    input: &ArrayD<f64>,
    indices: &ArrayD<i64>,
    axis: usize,
) -> Result<ArrayD<f64>, GatherScatterError> {
    let ndim = input.ndim();
    check_axis(axis, ndim)?;

    let index_ndim = indices.ndim();
    if index_ndim != ndim {
        return Err(GatherScatterError::IndexRankMismatch {
            input_ndim: ndim,
            index_ndim,
        });
    }

    let in_shape = input.shape();
    let idx_shape = indices.shape();
    let off_axis_mismatch = in_shape
        .iter()
        .zip(idx_shape)
        .enumerate()
        .any(|(d, (lhs, rhs))| d != axis && lhs != rhs);
    if off_axis_mismatch {
        return Err(GatherScatterError::ShapeMismatch {
            expected: in_shape.to_vec(),
            got: idx_shape.to_vec(),
        });
    }

    let axis_len = in_shape[axis];
    let output_shape = idx_shape.to_vec();
    let mut gathered: Vec<f64> = Vec::with_capacity(output_shape.iter().product::<usize>());

    for (position, &raw_idx) in indices.indexed_iter() {
        let gather_pos = normalize_index(raw_idx, axis_len, axis)?;
        // Same coordinates as the index element, except along the gather axis.
        let mut coords: Vec<usize> = position.slice().to_vec();
        coords[axis] = gather_pos;
        match input.get(IxDyn(&coords)) {
            Some(&value) => gathered.push(value),
            None => {
                return Err(GatherScatterError::OutOfBoundsIndex {
                    index: raw_idx,
                    axis_len,
                    axis,
                })
            }
        }
    }

    Array::from_shape_vec(IxDyn(&output_shape), gathered).map_err(|_| {
        GatherScatterError::ShapeMismatch {
            expected: output_shape,
            got: vec![],
        }
    })
}
/// Shared implementation of the scatter reductions.
///
/// The output has the shape of `input` with the length of `axis` replaced by
/// `output_size`, and starts out filled with `init_value`. Slice `i` of `input`
/// along `axis` is merged into output slice `indices[i]` element by element
/// using `combine(current_output, input_value)`.
///
/// # Errors
///
/// * [`GatherScatterError::AxisOutOfRange`] if `axis` is not a dimension of `input`.
/// * [`GatherScatterError::ShapeMismatch`] if `indices.len()` differs from the
///   length of `axis` in `input`.
/// * [`GatherScatterError::OutOfBoundsIndex`] if an index does not resolve to a
///   position in `0..output_size`.
fn scatter_generic<F>(
    input: &ArrayD<f64>,
    indices: &[i64],
    axis: usize,
    output_size: usize,
    init_value: f64,
    combine: F,
) -> Result<ArrayD<f64>, GatherScatterError>
where
    F: Fn(f64, f64) -> f64,
{
    check_axis(axis, input.ndim())?;

    let in_axis_len = input.shape()[axis];
    if indices.len() != in_axis_len {
        return Err(GatherScatterError::ShapeMismatch {
            expected: vec![in_axis_len],
            got: vec![indices.len()],
        });
    }

    let out_shape: Vec<usize> = input
        .shape()
        .iter()
        .enumerate()
        .map(|(d, &extent)| if d == axis { output_size } else { extent })
        .collect();
    let mut out_data: Vec<f64> = vec![init_value; out_shape.iter().product::<usize>()];

    for (src, &raw) in indices.iter().enumerate() {
        let dst = normalize_index(raw, output_size, axis)?;
        // Flat offset contributed by the destination position on `axis`.
        let base = compute_axis_offset(&out_shape, axis, dst);
        for (flat_in, &val) in input.index_axis(Axis(axis), src).iter().enumerate() {
            let slot = &mut out_data[base + slice_flat_index(&out_shape, axis, flat_in)];
            *slot = combine(*slot, val);
        }
    }

    Array::from_shape_vec(IxDyn(&out_shape), out_data).map_err(|_| {
        GatherScatterError::ShapeMismatch {
            expected: out_shape,
            got: vec![],
        }
    })
}
/// Flat offset of position `pos` along `axis`: `pos` times the product of all
/// extents that follow `axis` in `shape`.
fn compute_axis_offset(shape: &[usize], axis: usize, pos: usize) -> usize {
    let trailing = &shape[axis + 1..];
    let elements_per_step = trailing.iter().product::<usize>();
    pos * elements_per_step
}
/// Translates a flat position inside a slice that has `axis` removed into the
/// flat offset of the matching element in a buffer laid out with the full
/// `shape`, at position 0 along `axis`.
///
/// `flat_in` is split into an outer part (dimensions before `axis`) and an
/// inner part (dimensions after `axis`); the outer part is re-scaled by the
/// full stride, which includes the extent of `axis`.
fn slice_flat_index(shape: &[usize], axis: usize, flat_in: usize) -> usize {
    let inner_size = shape[axis + 1..].iter().product::<usize>();
    let (outer_idx, inner_idx) = (flat_in / inner_size, flat_in % inner_size);
    outer_idx * (shape[axis] * inner_size) + inner_idx
}
/// Scatters slices of `input` along `axis` into an output of length
/// `output_size` on that axis, summing every contribution onto `init_value`.
///
/// Slice `i` of `input` is added into output slice `indices[i]`; slices that
/// share a destination accumulate. Errors are those of the shared scatter
/// implementation (invalid axis, index-count mismatch, out-of-bounds index).
pub fn scatter_add(
    input: &ArrayD<f64>,
    indices: &[i64],
    axis: usize,
    output_size: usize,
    init_value: f64,
) -> Result<ArrayD<f64>, GatherScatterError> {
    let sum = |acc: f64, value: f64| acc + value;
    scatter_generic(input, indices, axis, output_size, init_value, sum)
}
/// Scatters slices of `input` along `axis` into an output of length
/// `output_size` on that axis, keeping the maximum per destination element.
///
/// Each output element starts at `init_value` and is combined with incoming
/// values via `f64::max`. Errors are those of the shared scatter
/// implementation (invalid axis, index-count mismatch, out-of-bounds index).
pub fn scatter_max(
    input: &ArrayD<f64>,
    indices: &[i64],
    axis: usize,
    output_size: usize,
    init_value: f64,
) -> Result<ArrayD<f64>, GatherScatterError> {
    let keep_larger = |acc: f64, value: f64| acc.max(value);
    scatter_generic(input, indices, axis, output_size, init_value, keep_larger)
}
/// Scatters slices of `input` along `axis` into an output of length
/// `output_size` on that axis, keeping the minimum per destination element.
///
/// Each output element starts at `init_value` and is combined with incoming
/// values via `f64::min`. Errors are those of the shared scatter
/// implementation (invalid axis, index-count mismatch, out-of-bounds index).
pub fn scatter_min(
    input: &ArrayD<f64>,
    indices: &[i64],
    axis: usize,
    output_size: usize,
    init_value: f64,
) -> Result<ArrayD<f64>, GatherScatterError> {
    let keep_smaller = |acc: f64, value: f64| acc.min(value);
    scatter_generic(input, indices, axis, output_size, init_value, keep_smaller)
}
/// Selects the `k` largest (or smallest) values along `axis`.
///
/// Returns `(values, indices)`, both shaped like `input` with the length of
/// `axis` replaced by `k`. `values` are ordered from best to worst (descending
/// when `largest` is true, ascending otherwise) and `indices` holds the
/// position along `axis` each value was taken from. Ties keep ascending index
/// order because the sort is stable.
///
/// NaN is ordered as greater than every number, so it comes first when
/// `largest` is true and last otherwise. This keeps the comparator a total
/// order, which `sort_by` requires.
///
/// # Errors
///
/// * [`GatherScatterError::AxisOutOfRange`] if `axis` is not a dimension of `input`.
/// * [`GatherScatterError::KTooLarge`] if `k` exceeds the length of `axis`.
/// * [`GatherScatterError::EmptyInput`] if `input` has no elements.
/// * [`GatherScatterError::ShapeMismatch`] if assembling the outputs fails.
pub fn top_k(
    input: &ArrayD<f64>,
    k: usize,
    axis: usize,
    largest: bool,
) -> Result<(ArrayD<f64>, ArrayD<i64>), GatherScatterError> {
    // Ascending comparison that treats NaN as the greatest value and all NaNs
    // as equal to each other.
    fn cmp_nan_greatest(a: &f64, b: &f64) -> std::cmp::Ordering {
        a.partial_cmp(b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    let ndim = input.ndim();
    check_axis(axis, ndim)?;
    let shape = input.shape();
    let axis_len = shape[axis];
    if k > axis_len {
        return Err(GatherScatterError::KTooLarge { k, axis_len });
    }
    if input.is_empty() {
        return Err(GatherScatterError::EmptyInput);
    }

    let mut out_shape = shape.to_vec();
    out_shape[axis] = k;
    let out_total = out_shape.iter().product::<usize>();
    let mut val_flat: Vec<f64> = Vec::with_capacity(out_total);
    let mut idx_flat: Vec<i64> = Vec::with_capacity(out_total);
    let outer_size: usize = shape[..axis].iter().product();
    let inner_size: usize = shape[axis + 1..].iter().product();

    // Scratch buffers reused for every lane along `axis`.
    let mut midx = vec![0usize; ndim];
    let mut pairs: Vec<(f64, usize)> = Vec::with_capacity(axis_len);

    for outer in 0..outer_size {
        // Coordinates before `axis` only change with `outer`.
        let outer_multi = compute_outer_multi_index(shape, axis, outer);
        midx[..axis].copy_from_slice(&outer_multi[..axis]);

        for inner in 0..inner_size {
            // Coordinates after `axis` only change with `inner`.
            let inner_multi = compute_inner_multi_index(shape, axis, inner);
            midx[axis + 1..].copy_from_slice(&inner_multi);

            pairs.clear();
            for a in 0..axis_len {
                midx[axis] = a;
                let val = input.get(IxDyn(&midx)).copied().unwrap_or(f64::NAN);
                pairs.push((val, a));
            }

            pairs.sort_by(|(a, _), (b, _)| {
                if largest {
                    cmp_nan_greatest(b, a)
                } else {
                    cmp_nan_greatest(a, b)
                }
            });

            for &(v, i) in pairs.iter().take(k) {
                val_flat.push(v);
                idx_flat.push(i as i64);
            }
        }
    }

    // The flat buffers are ordered (outer, inner, k); the helpers transpose
    // them to the (outer, k, inner) order expected by `out_shape`.
    let val_array = reorder_topk_output(&val_flat, outer_size, k, inner_size, &out_shape)?;
    let idx_array = reorder_topk_output_i64(&idx_flat, outer_size, k, inner_size, &out_shape)?;
    Ok((val_array, idx_array))
}
/// Shared transpose used by both top-k output builders.
///
/// `data` is read as a flat buffer ordered `(outer, inner, k)` and rewritten
/// in `(outer, k, inner)` order before being wrapped in an array of `shape`.
/// Returns [`GatherScatterError::ShapeMismatch`] when `data` does not contain
/// exactly `outer * k * inner` elements or the buffer does not fit `shape`.
fn reorder_topk_lanes<T: Copy>(
    data: &[T],
    outer: usize,
    k: usize,
    inner: usize,
    shape: &[usize],
) -> Result<ArrayD<T>, GatherScatterError> {
    let total = outer * k * inner;
    if data.len() != total {
        return Err(GatherScatterError::ShapeMismatch {
            expected: shape.to_vec(),
            got: vec![data.len()],
        });
    }

    // Walking (o, ki, i) visits destination offsets 0, 1, 2, ... in order, so
    // the output can be built by pushing instead of pre-filling and indexing.
    let mut out: Vec<T> = Vec::with_capacity(total);
    for o in 0..outer {
        for ki in 0..k {
            for i in 0..inner {
                out.push(data[o * inner * k + i * k + ki]);
            }
        }
    }

    Array::from_shape_vec(IxDyn(shape), out).map_err(|_| GatherScatterError::ShapeMismatch {
        expected: shape.to_vec(),
        got: vec![],
    })
}

/// Builds the top-k value array from a flat `(outer, inner, k)` buffer.
fn reorder_topk_output(
    data: &[f64],
    outer: usize,
    k: usize,
    inner: usize,
    shape: &[usize],
) -> Result<ArrayD<f64>, GatherScatterError> {
    reorder_topk_lanes(data, outer, k, inner, shape)
}

/// Builds the top-k index array from a flat `(outer, inner, k)` buffer.
fn reorder_topk_output_i64(
    data: &[i64],
    outer: usize,
    k: usize,
    inner: usize,
    shape: &[usize],
) -> Result<ArrayD<i64>, GatherScatterError> {
    reorder_topk_lanes(data, outer, k, inner, shape)
}
/// Decomposes `flat` into one coordinate per dimension before `axis`.
///
/// The last of those dimensions varies fastest; the result has `axis` entries.
fn compute_outer_multi_index(shape: &[usize], axis: usize, flat: usize) -> Vec<usize> {
    let mut coords = vec![0usize; axis];
    let mut rest = flat;
    // Peel coordinates off from the fastest-varying dimension backwards.
    for (slot, &extent) in coords.iter_mut().zip(&shape[..axis]).rev() {
        *slot = rest % extent;
        rest /= extent;
    }
    coords
}
/// Decomposes `flat` into one coordinate per dimension after `axis`.
///
/// The last dimension varies fastest; the result has `shape.len() - axis - 1`
/// entries.
fn compute_inner_multi_index(shape: &[usize], axis: usize, flat: usize) -> Vec<usize> {
    let trailing = &shape[axis + 1..];
    let mut coords = vec![0usize; trailing.len()];
    let mut rest = flat;
    // Peel coordinates off from the fastest-varying dimension backwards.
    for (slot, &extent) in coords.iter_mut().zip(trailing).rev() {
        *slot = rest % extent;
        rest /= extent;
    }
    coords
}
/// Collects the elements of `input` whose `mask` entry is `true` into a 1-D array.
///
/// Elements appear in the order produced by iterating `input` and `mask` side
/// by side. Returns [`GatherScatterError::ShapeMismatch`] when the two shapes
/// differ.
pub fn masked_select(
    input: &ArrayD<f64>,
    mask: &ArrayD<bool>,
) -> Result<Array1<f64>, GatherScatterError> {
    let (input_shape, mask_shape) = (input.shape(), mask.shape());
    if input_shape != mask_shape {
        return Err(GatherScatterError::ShapeMismatch {
            expected: input_shape.to_vec(),
            got: mask_shape.to_vec(),
        });
    }

    let mut selected: Vec<f64> = Vec::new();
    for (value, keep) in input.iter().zip(mask.iter()) {
        if *keep {
            selected.push(*value);
        }
    }
    Ok(Array1::from(selected))
}
/// Returns a copy of `input` in which every element whose `mask` entry is
/// `true` is replaced by `fill_value`.
///
/// Returns [`GatherScatterError::ShapeMismatch`] when the shapes of `input`
/// and `mask` differ, or if the result buffer cannot be reshaped to `input`'s
/// shape.
pub fn masked_fill(
    input: &ArrayD<f64>,
    mask: &ArrayD<bool>,
    fill_value: f64,
) -> Result<ArrayD<f64>, GatherScatterError> {
    let (input_shape, mask_shape) = (input.shape(), mask.shape());
    if input_shape != mask_shape {
        return Err(GatherScatterError::ShapeMismatch {
            expected: input_shape.to_vec(),
            got: mask_shape.to_vec(),
        });
    }

    let mut filled: Vec<f64> = Vec::with_capacity(input.len());
    for (value, replace) in input.iter().zip(mask.iter()) {
        filled.push(if *replace { fill_value } else { *value });
    }

    match Array::from_shape_vec(IxDyn(input_shape), filled) {
        Ok(array) => Ok(array),
        Err(_) => Err(GatherScatterError::ShapeMismatch {
            expected: input_shape.to_vec(),
            got: vec![],
        }),
    }
}
/// Summary statistics over a list of raw (un-normalized) indices.
#[derive(Debug, Clone)]
pub struct IndexStats {
    /// Smallest index value seen (0 when there are no indices).
    pub min_index: i64,
    /// Largest index value seen (0 when there are no indices).
    pub max_index: i64,
    /// Number of distinct index values.
    pub unique_indices: usize,
    /// Number of indices supplied, duplicates included.
    pub total_indices: usize,
    /// Whether any value occurs more than once.
    pub has_duplicates: bool,
    /// Whether any value is below zero.
    pub has_negatives: bool,
    /// `unique_indices / output_size`, or NaN when it cannot be computed.
    pub coverage: f64,
}

impl IndexStats {
    /// Computes the statistics for `indices`.
    ///
    /// Values are compared as given, so a negative index and its wrapped
    /// positive counterpart count as two distinct values. `coverage` is NaN
    /// when `indices` is empty, or when `output_size` is `None` or `Some(0)`.
    pub fn compute(indices: &[i64], output_size: Option<usize>) -> Self {
        let (min_index, max_index) = match (indices.iter().min(), indices.iter().max()) {
            (Some(&lo), Some(&hi)) => (lo, hi),
            // No indices at all: report neutral values and an undefined coverage.
            _ => {
                return IndexStats {
                    min_index: 0,
                    max_index: 0,
                    unique_indices: 0,
                    total_indices: 0,
                    has_duplicates: false,
                    has_negatives: false,
                    coverage: f64::NAN,
                }
            }
        };

        let distinct: HashSet<i64> = indices.iter().copied().collect();
        let unique_indices = distinct.len();
        let total_indices = indices.len();

        let coverage = match output_size {
            None | Some(0) => f64::NAN,
            Some(sz) => unique_indices as f64 / sz as f64,
        };

        IndexStats {
            min_index,
            max_index,
            unique_indices,
            total_indices,
            has_duplicates: unique_indices < total_indices,
            has_negatives: indices.iter().any(|&idx| idx < 0),
            coverage,
        }
    }

    /// Reports whether the summarized indices form a permutation of `0..size`.
    ///
    /// That requires exactly `size` indices, no duplicates, no negatives, and
    /// a value range of exactly `0..=size - 1`. Zero indices form the empty
    /// permutation when `size` is 0.
    pub fn is_permutation(&self, size: usize) -> bool {
        let plausible =
            self.total_indices == size && !self.has_duplicates && !self.has_negatives;
        if !plausible {
            return false;
        }
        size == 0 || (self.min_index == 0 && self.max_index == (size as i64 - 1))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array, IxDyn};

    /// Builds a 2-D dynamic array from row vectors (all rows must share a length).
    fn make_2d(rows: Vec<Vec<f64>>) -> ArrayD<f64> {
        let nrows = rows.len();
        let ncols = rows[0].len();
        let flat: Vec<f64> = rows.into_iter().flatten().collect();
        Array::from_shape_vec(IxDyn(&[nrows, ncols]), flat).expect("make_2d: shape mismatch")
    }

    /// Builds a 1-D dynamic array from a vector.
    fn make_1d(data: Vec<f64>) -> ArrayD<f64> {
        let n = data.len();
        Array::from_shape_vec(IxDyn(&[n]), data).expect("make_1d")
    }

    #[test]
    fn test_gather_axis0_basic() {
        let input = make_2d(vec![
            vec![0.0, 1.0, 2.0],
            vec![3.0, 4.0, 5.0],
            vec![6.0, 7.0, 8.0],
            vec![9.0, 10.0, 11.0],
        ]);
        let result = gather(&input, &[1, 3], 0).expect("gather axis0");
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(result[[0, 0]], 3.0);
        assert_eq!(result[[0, 2]], 5.0);
        assert_eq!(result[[1, 0]], 9.0);
        assert_eq!(result[[1, 2]], 11.0);
    }

    #[test]
    fn test_gather_axis1_basic() {
        let input = make_2d(vec![
            vec![0.0, 1.0, 2.0, 3.0, 4.0],
            vec![5.0, 6.0, 7.0, 8.0, 9.0],
            vec![10.0, 11.0, 12.0, 13.0, 14.0],
        ]);
        let result = gather(&input, &[0, 2, 4], 1).expect("gather axis1");
        assert_eq!(result.shape(), &[3, 3]);
        assert_eq!(result[[0, 1]], 2.0);
        assert_eq!(result[[2, 2]], 14.0);
    }

    #[test]
    fn test_gather_negative_index() {
        let input = make_2d(vec![
            vec![0.0, 1.0],
            vec![2.0, 3.0],
            vec![4.0, 5.0],
            vec![6.0, 7.0],
        ]);
        let result = gather(&input, &[-1], 0).expect("gather negative");
        assert_eq!(result.shape(), &[1, 2]);
        assert_eq!(result[[0, 0]], 6.0);
        assert_eq!(result[[0, 1]], 7.0);
    }

    #[test]
    fn test_gather_out_of_bounds() {
        let input = make_2d(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        let err = gather(&input, &[5], 0).unwrap_err();
        // `matches!` only yields a bool; it must be asserted to check anything.
        assert!(matches!(err, GatherScatterError::OutOfBoundsIndex { .. }));
    }

    #[test]
    fn test_gather_nd_basic() {
        let input = make_2d(vec![
            vec![10.0, 20.0, 30.0, 40.0],
            vec![50.0, 60.0, 70.0, 80.0],
            vec![90.0, 100.0, 110.0, 120.0],
        ]);
        let idx_data: Vec<i64> = vec![3, 2, 1, 0, 0, 1, 2, 3, 2, 2, 2, 2];
        let indices = Array::from_shape_vec(IxDyn(&[3, 4]), idx_data).expect("indices shape");
        let result = gather_nd(&input, &indices, 1).expect("gather_nd");
        assert_eq!(result.shape(), &[3, 4]);
        assert_eq!(result[[0, 0]], 40.0);
        assert_eq!(result[[0, 1]], 30.0);
        assert_eq!(result[[2, 0]], 110.0);
    }

    #[test]
    fn test_scatter_add_basic() {
        let input = make_2d(vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![6.0, 7.0, 8.0, 9.0, 10.0],
            vec![11.0, 12.0, 13.0, 14.0, 15.0],
        ]);
        let result = scatter_add(&input, &[0, 2, 1], 0, 4, 0.0).expect("scatter_add basic");
        assert_eq!(result.shape(), &[4, 5]);
        assert_eq!(result[[0, 0]], 1.0);
        assert_eq!(result[[1, 0]], 11.0);
        assert_eq!(result[[2, 0]], 6.0);
        assert_eq!(result[[3, 0]], 0.0);
    }

    #[test]
    fn test_scatter_add_duplicate_indices() {
        let input = make_2d(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
        let result = scatter_add(&input, &[0, 0, 1], 0, 2, 0.0).expect("scatter_add dup");
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result[[0, 0]], 4.0);
        assert_eq!(result[[0, 1]], 6.0);
        assert_eq!(result[[1, 0]], 5.0);
    }

    #[test]
    fn test_scatter_add_shape() {
        let input = make_1d(vec![1.0, 2.0, 3.0]);
        let result = scatter_add(&input, &[0, 2, 4], 0, 6, 0.0).expect("scatter_add shape");
        assert_eq!(result.shape(), &[6]);
    }

    #[test]
    fn test_scatter_max_basic() {
        let input = make_2d(vec![vec![5.0, 1.0], vec![3.0, 9.0], vec![7.0, 2.0]]);
        let result =
            scatter_max(&input, &[0, 0, 1], 0, 2, f64::NEG_INFINITY).expect("scatter_max");
        assert_eq!(result[[0, 0]], 5.0);
        assert_eq!(result[[0, 1]], 9.0);
        assert_eq!(result[[1, 0]], 7.0);
        assert_eq!(result[[1, 1]], 2.0);
    }

    #[test]
    fn test_scatter_min_basic() {
        let input = make_2d(vec![vec![5.0, 1.0], vec![3.0, 9.0], vec![7.0, 2.0]]);
        let result = scatter_min(&input, &[0, 0, 1], 0, 2, f64::INFINITY).expect("scatter_min");
        assert_eq!(result[[0, 0]], 3.0);
        assert_eq!(result[[0, 1]], 1.0);
        assert_eq!(result[[1, 0]], 7.0);
    }

    #[test]
    fn test_top_k_largest() {
        let input = make_1d(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0]);
        let (vals, idxs) = top_k(&input, 2, 0, true).expect("top_k largest");
        assert_eq!(vals.shape(), &[2]);
        assert_eq!(vals[[0]], 9.0);
        assert_eq!(vals[[1]], 5.0);
        assert_eq!(idxs[[0]], 5);
        assert_eq!(idxs[[1]], 4);
    }

    #[test]
    fn test_top_k_smallest() {
        let input = make_1d(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0]);
        let (vals, _idxs) = top_k(&input, 2, 0, false).expect("top_k smallest");
        assert_eq!(vals.shape(), &[2]);
        assert_eq!(vals[[0]], 1.0);
        assert_eq!(vals[[1]], 1.0);
    }

    #[test]
    fn test_top_k_k_larger_than_dim() {
        let input = make_1d(vec![1.0, 2.0]);
        let err = top_k(&input, 5, 0, true).unwrap_err();
        assert!(matches!(err, GatherScatterError::KTooLarge { .. }));
    }

    #[test]
    fn test_masked_select_basic() {
        let input = make_2d(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]);
        let mask_data = vec![true, false, true, false, true, false];
        let mask = Array::from_shape_vec(IxDyn(&[2, 3]), mask_data).expect("mask shape");
        let result = masked_select(&input, &mask).expect("masked_select");
        assert_eq!(result.len(), 3);
        assert_eq!(result[0], 1.0);
        assert_eq!(result[1], 3.0);
        assert_eq!(result[2], 5.0);
    }

    #[test]
    fn test_masked_select_all_false() {
        let input = make_1d(vec![1.0, 2.0, 3.0]);
        let mask = Array::from_shape_vec(IxDyn(&[3]), vec![false, false, false]).expect("mask");
        let result = masked_select(&input, &mask).expect("masked_select all_false");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_masked_fill_basic() {
        let input = make_1d(vec![1.0, 2.0, 3.0, 4.0]);
        let mask =
            Array::from_shape_vec(IxDyn(&[4]), vec![false, true, false, true]).expect("mask");
        let result = masked_fill(&input, &mask, -99.0).expect("masked_fill");
        assert_eq!(result.shape(), &[4]);
        assert_eq!(result[[0]], 1.0);
        assert_eq!(result[[1]], -99.0);
        assert_eq!(result[[2]], 3.0);
        assert_eq!(result[[3]], -99.0);
    }

    #[test]
    fn test_masked_fill_shape_mismatch() {
        let input = make_1d(vec![1.0, 2.0, 3.0]);
        let mask = Array::from_shape_vec(IxDyn(&[2]), vec![true, false]).expect("mask");
        let err = masked_fill(&input, &mask, 0.0).unwrap_err();
        assert!(matches!(err, GatherScatterError::ShapeMismatch { .. }));
    }

    #[test]
    fn test_index_stats_basic() {
        let indices: Vec<i64> = vec![0, 2, 4, 2, -1];
        let stats = IndexStats::compute(&indices, Some(5));
        assert_eq!(stats.min_index, -1);
        assert_eq!(stats.max_index, 4);
        assert_eq!(stats.total_indices, 5);
        assert_eq!(stats.unique_indices, 4);
        assert!(stats.has_duplicates);
        assert!(stats.has_negatives);
        assert!((stats.coverage - 0.8).abs() < 1e-10);
    }

    #[test]
    fn test_index_stats_is_permutation_true() {
        let indices: Vec<i64> = vec![0, 1, 2, 3];
        let stats = IndexStats::compute(&indices, Some(4));
        assert!(stats.is_permutation(4));
    }

    #[test]
    fn test_index_stats_is_permutation_false() {
        let indices: Vec<i64> = vec![0, 0, 1];
        let stats = IndexStats::compute(&indices, Some(3));
        assert!(!stats.is_permutation(3));
    }
}