use ferray_core::Array;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
/// Copies every element of `input` into a freshly allocated `Vec`.
///
/// When the array exposes its storage as a single slice, that slice is copied
/// directly; otherwise the elements are gathered through `iter()`.
///
/// NOTE(review): the index arithmetic in the callers assumes the returned order is
/// logical row-major order — confirm `as_slice` only returns `Some` for C-ordered data.
fn contiguous_data<T, D>(input: &Array<T, D>) -> Vec<T>
where
    T: Element + Copy,
    D: Dimension,
{
    match input.as_slice() {
        Some(contiguous) => contiguous.to_vec(),
        None => input.iter().copied().collect(),
    }
}
/// Splits `shape` around `axis` into `(outer_size, axis_len, inner_size)`.
///
/// `outer_size` is the product of the dimensions before `axis`, `axis_len` is
/// `shape[axis]`, and `inner_size` is the product of the dimensions after `axis`.
/// Either product is 1 when there are no such dimensions.
///
/// Panics (index out of bounds) if `axis >= shape.len()`; callers validate `axis` first.
#[inline]
fn axis_layout(shape: &[usize], axis: usize) -> (usize, usize, usize) {
    let len = shape[axis];
    let before: usize = shape.iter().take(axis).product();
    let after: usize = shape.iter().skip(axis + 1).product();
    (before, len, after)
}
/// Reduces `input` along a single `axis` with the binary operator `op`.
///
/// Each output element is `identity` folded left-to-right over the lane along `axis`:
/// `op(...op(op(identity, x0), x1)..., x_last)`. The reduced axis is removed from the
/// output shape. When the input is 1-D, the result has shape `[1]` rather than being
/// 0-dimensional.
///
/// An empty lane (`shape[axis] == 0`) yields `identity`. A zero-length *kept* axis
/// yields an empty result that keeps that zero-length dimension.
///
/// # Errors
/// Returns an axis-out-of-bounds error if `axis >= input.ndim()`.
pub fn reduce_axis<T, D, F>(
    input: &Array<T, D>,
    axis: usize,
    identity: T,
    op: F,
) -> FerrayResult<Array<T, IxDyn>>
where
    T: Element + Copy,
    D: Dimension,
    F: Fn(T, T) -> T,
{
    let ndim = input.ndim();
    if axis >= ndim {
        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
    }
    let shape = input.shape().to_vec();
    let (outer_size, axis_len, inner_size) = axis_layout(&shape, axis);

    let mut out_shape: Vec<usize> = shape
        .iter()
        .enumerate()
        .filter(|&(i, _)| i != axis)
        .map(|(_, &s)| s)
        .collect();

    // One output element per (outer, inner) pair. Do not clamp this to 1: a zero-length
    // kept axis must produce an empty buffer that matches `out_shape`. For a 1-D input
    // both products are already 1.
    let out_size = outer_size * inner_size;

    let data = contiguous_data(input);
    // Row-major layout: element (outer, k, inner) lives at
    // outer * axis_len * inner_size + k * inner_size + inner.
    let block = axis_len * inner_size;
    let mut result = Vec::with_capacity(out_size);
    for outer in 0..outer_size {
        let base = outer * block;
        for inner in 0..inner_size {
            let acc = (0..axis_len)
                .fold(identity, |acc, k| op(acc, data[base + k * inner_size + inner]));
            result.push(acc);
        }
    }

    // Reducing the only axis of a 1-D array is reported as shape [1].
    if out_shape.is_empty() {
        out_shape.push(1);
    }
    Array::from_vec(IxDyn::from(&out_shape[..]), result)
}
/// Like [`reduce_axis`], but can keep the reduced axis as a length-1 dimension.
///
/// With `keepdims == false` this returns exactly what `reduce_axis` returns. With
/// `keepdims == true` the same values are returned with the input's shape, except
/// that `shape[axis]` becomes 1.
///
/// # Errors
/// Returns an axis-out-of-bounds error if `axis >= input.ndim()`, and propagates any
/// error from `reduce_axis` or from building the reshaped array.
pub fn reduce_axis_keepdims<T, D, F>(
    input: &Array<T, D>,
    axis: usize,
    identity: T,
    keepdims: bool,
    op: F,
) -> FerrayResult<Array<T, IxDyn>>
where
    T: Element + Copy,
    D: Dimension,
    F: Fn(T, T) -> T,
{
    let ndim = input.ndim();
    if axis >= ndim {
        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
    }
    let reduced = reduce_axis(input, axis, identity, op)?;
    if keepdims {
        // Same element count as `reduced`; only the axis length changes (to 1).
        let kept_shape: Vec<usize> = input
            .shape()
            .iter()
            .enumerate()
            .map(|(i, &d)| if i == axis { 1 } else { d })
            .collect();
        let values: Vec<T> = reduced.iter().copied().collect();
        Array::from_vec(IxDyn::new(&kept_shape), values)
    } else {
        Ok(reduced)
    }
}
/// Reduces `input` over several axes at once with the binary operator `op`.
///
/// `axes` may be given in any order. Every output element starts at `identity`.
/// `op(acc, x)` is then applied for each contributing input element, in the order
/// produced by `contiguous_data`.
///
/// Output shape:
/// - `keepdims == true`: the input shape with every reduced axis set to 1.
/// - `keepdims == false`: the input shape with the reduced axes removed, or `[1]` if
///   every axis was reduced.
///
/// An empty `axes` slice performs no reduction. It returns a copy of `input` with its
/// original shape, regardless of `keepdims`. A zero-length kept axis yields an empty
/// result.
///
/// # Errors
/// - `invalid_value` if `axes` contains a duplicate.
/// - axis-out-of-bounds if any axis is `>= input.ndim()`.
pub fn reduce_axes<T, D, F>(
    input: &Array<T, D>,
    axes: &[usize],
    identity: T,
    keepdims: bool,
    op: F,
) -> FerrayResult<Array<T, IxDyn>>
where
    T: Element + Copy,
    D: Dimension,
    F: Fn(T, T) -> T,
{
    let ndim = input.ndim();
    let shape: Vec<usize> = input.shape().to_vec();

    let mut sorted_axes: Vec<usize> = axes.to_vec();
    sorted_axes.sort_unstable();
    // After sorting, any duplicate appears as an adjacent equal pair.
    if let Some(pair) = sorted_axes.windows(2).find(|w| w[0] == w[1]) {
        return Err(FerrayError::invalid_value(format!(
            "reduce_axes: duplicate axis {}",
            pair[0]
        )));
    }
    if let Some(&ax) = sorted_axes.iter().find(|&&ax| ax >= ndim) {
        return Err(FerrayError::axis_out_of_bounds(ax, ndim));
    }

    if sorted_axes.is_empty() {
        let data = contiguous_data(input);
        return Array::from_vec(IxDyn::new(&shape), data);
    }

    // reduced[i] is true when input axis i is folded away.
    let reduced: Vec<bool> = (0..ndim)
        .map(|i| sorted_axes.binary_search(&i).is_ok())
        .collect();

    let out_shape: Vec<usize> = if keepdims {
        shape
            .iter()
            .zip(&reduced)
            .map(|(&d, &r)| if r { 1 } else { d })
            .collect()
    } else {
        let kept: Vec<usize> = shape
            .iter()
            .zip(&reduced)
            .filter(|&(_, &r)| !r)
            .map(|(&d, _)| d)
            .collect();
        if kept.is_empty() {
            vec![1]
        } else {
            kept
        }
    };

    // Row-major strides of the input buffer.
    let mut in_strides = vec![1usize; ndim];
    for i in (0..ndim.saturating_sub(1)).rev() {
        in_strides[i] = in_strides[i + 1] * shape[i + 1];
    }

    // Per input axis, the stride it contributes to the flat output index. Reduced axes
    // contribute 0, because all of their positions fold into the same output slot.
    // `out_size` ends up as the product of the kept dims: 1 when every axis is reduced,
    // 0 when a kept axis is empty. That matches `out_shape`'s element count in both
    // keepdims modes.
    let mut axis_out_strides = vec![0usize; ndim];
    let mut out_size = 1usize;
    for i in (0..ndim).rev() {
        if !reduced[i] {
            axis_out_strides[i] = out_size;
            out_size *= shape[i];
        }
    }

    let data = contiguous_data(input);
    let mut result = vec![identity; out_size];
    // If any dimension is 0 then `data` is empty and this loop never divides by a
    // zero stride.
    for (flat, &x) in data.iter().enumerate() {
        let mut rem = flat;
        let mut out_flat = 0usize;
        for (&in_stride, &out_stride) in in_strides.iter().zip(&axis_out_strides) {
            out_flat += (rem / in_stride) * out_stride;
            rem %= in_stride;
        }
        result[out_flat] = op(result[out_flat], x);
    }

    Array::from_vec(IxDyn::new(&out_shape), result)
}
/// Folds every element of `input` into one value: `op(...op(op(identity, x0), x1)..., x_last)`.
///
/// Elements come from the contiguous slice when `as_slice` provides one, and from
/// `iter()` otherwise. An empty array returns `identity` unchanged.
pub fn reduce_all<T, D, F>(input: &Array<T, D>, identity: T, op: F) -> T
where
    T: Element + Copy,
    D: Dimension,
    F: Fn(T, T) -> T,
{
    let step = |acc: T, x: T| op(acc, x);
    match input.as_slice() {
        Some(values) => values.iter().copied().fold(identity, step),
        None => input.iter().copied().fold(identity, step),
    }
}
/// Computes the running (inclusive-scan) accumulation of `op` along `axis`.
///
/// The output has the same dimension as `input`. Along each lane, the first element
/// is copied unchanged and each later element becomes `op(previous_output, current_input)`.
/// No identity value is involved.
///
/// # Errors
/// Returns an axis-out-of-bounds error if `axis >= input.ndim()`, and propagates any
/// error from building the output array.
pub fn accumulate_axis<T, D, F>(
    input: &Array<T, D>,
    axis: usize,
    op: F,
) -> FerrayResult<Array<T, D>>
where
    T: Element + Copy,
    D: Dimension,
    F: Fn(T, T) -> T,
{
    let ndim = input.ndim();
    if axis >= ndim {
        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
    }
    let shape = input.shape().to_vec();
    let (outer_size, axis_len, inner_size) = axis_layout(&shape, axis);
    let block = axis_len * inner_size;
    let mut result = contiguous_data(input);

    for o in 0..outer_size {
        for i in 0..inner_size {
            // Flat positions of this lane, in order along `axis` (row-major layout).
            let origin = o * block + i;
            let mut lane = (0..axis_len).map(|k| origin + k * inner_size);
            // An empty lane (axis_len == 0) has nothing to scan.
            if let Some(first) = lane.next() {
                let mut running = result[first];
                for pos in lane {
                    running = op(running, result[pos]);
                    result[pos] = running;
                }
            }
        }
    }
    Array::from_vec(input.dim().clone(), result)
}
/// Applies `op` to every pair drawn from two 1-D arrays, producing an `m x n` array.
///
/// The element at row `i` and column `j` is `op(a[i], b[j])`, laid out row by row
/// (`m = a.size()`, `n = b.size()`).
///
/// # Errors
/// Propagates any error from building the output array.
pub fn outer<T, F>(
    a: &Array<T, ferray_core::dimension::Ix1>,
    b: &Array<T, ferray_core::dimension::Ix1>,
    op: F,
) -> FerrayResult<Array<T, IxDyn>>
where
    T: Element + Copy,
    F: Fn(T, T) -> T,
{
    let rows = a.size();
    let cols = b.size();
    let mut data = Vec::with_capacity(rows * cols);
    match (a.as_slice(), b.as_slice()) {
        (Some(xs), Some(ys)) => {
            for &x in xs {
                data.extend(ys.iter().map(|&y| op(x, y)));
            }
        }
        _ => {
            for x in a.iter().copied() {
                data.extend(b.iter().copied().map(|y| op(x, y)));
            }
        }
    }
    Array::from_vec(IxDyn::from(&[rows, cols][..]), data)
}
/// Applies `op` in place at selected positions of a 1-D array (unbuffered).
///
/// For each pair `(indices[j], values[j])`, in order, the element is updated as
/// `arr[indices[j]] = op(arr[indices[j]], values[j])`. Repeated indices accumulate:
/// each occurrence sees the result of the previous update.
///
/// All indices are validated before any element is written, so when an error is
/// returned `arr` has not been modified.
///
/// # Errors
/// - `shape_mismatch` if `indices` and `values` have different lengths.
/// - `invalid_value` if `arr` cannot be viewed as a contiguous mutable slice.
/// - `invalid_value` if any index is `>= arr.size()`.
pub fn at<T, F>(
    arr: &mut Array<T, ferray_core::dimension::Ix1>,
    indices: &[usize],
    values: &[T],
    op: F,
) -> FerrayResult<()>
where
    T: Element + Copy,
    F: Fn(T, T) -> T,
{
    if indices.len() != values.len() {
        return Err(FerrayError::shape_mismatch(format!(
            "at: indices has length {} but values has length {}",
            indices.len(),
            values.len()
        )));
    }
    let n = arr.size();
    let slice = arr
        .as_slice_mut()
        .ok_or_else(|| FerrayError::invalid_value("at: array must be contiguous (C-order)"))?;
    // Validate every index up front. Otherwise a bad index late in the list would
    // return Err after earlier updates had already been written.
    if let Some(&i) = indices.iter().find(|&&i| i >= n) {
        return Err(FerrayError::invalid_value(format!(
            "at: index {i} out of bounds for length {n}"
        )));
    }
    for (&i, &v) in indices.iter().zip(values) {
        slice[i] = op(slice[i], v);
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    // Unit tests for the reduction, accumulation, outer-product and `at` helpers in this
    // file. Inputs are small hand-checked f64 arrays built with `arr1` (from
    // `crate::test_util`) and the local `arr2` helper.
    use super::*;
    use ferray_core::dimension::Ix2;
    use crate::test_util::arr1;

    // Builds a `rows x cols` f64 array from `data` (row-major, as the row/column sum
    // tests below rely on). Unwraps the `from_vec` result, so a construction error
    // panics the test.
    fn arr2(rows: usize, cols: usize, data: &[f64]) -> Array<f64, Ix2> {
        Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), data.to_vec()).unwrap()
    }

    // ---- reduce_axis ----

    #[test]
    fn reduce_axis_add_1d() {
        let a = arr1([1.0, 2.0, 3.0, 4.0]);
        let r = reduce_axis(&a, 0, 0.0, |acc, x| acc + x).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[10.0]);
    }

    // Reducing axis 1 of a 2x3 array sums each row.
    #[test]
    fn reduce_axis_add_2d_rows() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let r = reduce_axis(&a, 1, 0.0, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[2]);
        assert_eq!(r.as_slice().unwrap(), &[6.0, 15.0]);
    }

    // Reducing axis 0 of a 2x3 array sums each column.
    #[test]
    fn reduce_axis_add_2d_cols() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let r = reduce_axis(&a, 0, 0.0, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[3]);
        assert_eq!(r.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn reduce_axis_multiply_product() {
        let a = arr1([1.0, 2.0, 3.0, 4.0]);
        let r = reduce_axis(&a, 0, 1.0, |acc, x| acc * x).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[24.0]);
    }

    // NEG_INFINITY serves as the identity for max.
    #[test]
    fn reduce_axis_max() {
        let a = arr1([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]);
        let r = reduce_axis(&a, 0, f64::NEG_INFINITY, f64::max).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[9.0]);
    }

    #[test]
    fn reduce_axis_out_of_bounds() {
        let a = arr1([1.0, 2.0, 3.0]);
        assert!(reduce_axis(&a, 1, 0.0, |x, y| x + y).is_err());
    }

    // ---- reduce_axis_keepdims ----

    #[test]
    fn reduce_axis_keepdims_2d_rows_preserves_axis() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let r = reduce_axis_keepdims(&a, 1, 0.0, true, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[2, 1]);
        assert_eq!(r.as_slice().unwrap(), &[6.0, 15.0]);
    }

    #[test]
    fn reduce_axis_keepdims_2d_cols_preserves_axis() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let r = reduce_axis_keepdims(&a, 0, 0.0, true, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[1, 3]);
        assert_eq!(r.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
    }

    // keepdims == false must be indistinguishable from plain reduce_axis.
    #[test]
    fn reduce_axis_keepdims_false_matches_reduce_axis() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let with_flag = reduce_axis_keepdims(&a, 1, 0.0, false, |acc, x| acc + x).unwrap();
        let without = reduce_axis(&a, 1, 0.0, |acc, x| acc + x).unwrap();
        assert_eq!(with_flag.shape(), without.shape());
        assert_eq!(with_flag.as_slice().unwrap(), without.as_slice().unwrap());
    }

    // Only the shape is checked here: the middle axis collapses to length 1.
    #[test]
    fn reduce_axis_keepdims_3d_middle_axis() {
        use ferray_core::dimension::Ix3;
        let data: Vec<f64> = (0..24).map(f64::from).collect();
        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
        let r = reduce_axis_keepdims(&a, 1, 0.0, true, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[2, 1, 4]);
    }

    // The axis is validated for both values of keepdims.
    #[test]
    fn reduce_axis_keepdims_out_of_bounds_errors() {
        let a = arr1([1.0, 2.0, 3.0]);
        assert!(reduce_axis_keepdims(&a, 5, 0.0, true, |x, y| x + y).is_err());
        assert!(reduce_axis_keepdims(&a, 5, 0.0, false, |x, y| x + y).is_err());
    }

    #[test]
    fn reduce_axis_keepdims_result_is_broadcastable() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 10.0, 20.0, 30.0]);
        let row_sums = reduce_axis_keepdims(&a, 1, 0.0, true, |acc, x| acc + x).unwrap();
        assert_eq!(row_sums.shape(), &[2, 1]);
        let row_sums_slice = row_sums.as_slice().unwrap();
        assert_eq!(row_sums_slice, &[6.0, 60.0]);
    }

    // ---- reduce_axes ----

    #[test]
    fn reduce_axes_single_axis_matches_reduce_axis() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let multi = reduce_axes(&a, &[1], 0.0, false, |acc, x| acc + x).unwrap();
        let single = reduce_axis(&a, 1, 0.0, |acc, x| acc + x).unwrap();
        assert_eq!(multi.shape(), single.shape());
        assert_eq!(multi.as_slice().unwrap(), single.as_slice().unwrap());
    }

    // Expected values are computed by brute force from the row-major index
    // i * 12 + j * 4 + k of a [2, 3, 4] array.
    #[test]
    fn reduce_axes_two_axes_3d_collapse_to_vector() {
        use ferray_core::dimension::Ix3;
        let data: Vec<f64> = (0..24).map(f64::from).collect();
        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
        let r = reduce_axes(&a, &[0, 2], 0.0, false, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[3]);
        let expected: Vec<f64> = (0..3)
            .map(|j| {
                let mut s = 0.0;
                for i in 0..2 {
                    for k in 0..4 {
                        s += f64::from(i * 12 + j * 4 + k);
                    }
                }
                s
            })
            .collect();
        assert_eq!(r.as_slice().unwrap(), expected.as_slice());
    }

    // Axis order in the argument must not matter.
    #[test]
    fn reduce_axes_unsorted_axes_input_works() {
        use ferray_core::dimension::Ix3;
        let data: Vec<f64> = (0..24).map(f64::from).collect();
        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
        let r1 = reduce_axes(&a, &[0, 2], 0.0, false, |acc, x| acc + x).unwrap();
        let r2 = reduce_axes(&a, &[2, 0], 0.0, false, |acc, x| acc + x).unwrap();
        assert_eq!(r1.shape(), r2.shape());
        assert_eq!(r1.as_slice().unwrap(), r2.as_slice().unwrap());
    }

    #[test]
    fn reduce_axes_keepdims_preserves_reduced_axes_as_size_1() {
        use ferray_core::dimension::Ix3;
        let data: Vec<f64> = (0..24).map(f64::from).collect();
        let a = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
        let r = reduce_axes(&a, &[0, 2], 0.0, true, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[1, 3, 1]);
    }

    // Reducing every axis without keepdims yields shape [1], not a 0-d array.
    #[test]
    fn reduce_axes_all_axes_collapses_to_scalar() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let r = reduce_axes(&a, &[0, 1], 0.0, false, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[1]);
        assert_eq!(r.as_slice().unwrap(), &[21.0]);
    }

    #[test]
    fn reduce_axes_all_axes_keepdims_gives_size_1_per_axis() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let r = reduce_axes(&a, &[0, 1], 0.0, true, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[1, 1]);
        assert_eq!(r.as_slice().unwrap(), &[21.0]);
    }

    // No axes means no reduction: the data and shape come back unchanged.
    #[test]
    fn reduce_axes_empty_axes_is_identity_copy() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let r = reduce_axes(&a, &[], 0.0, false, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(r.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn reduce_axes_duplicate_axis_errors() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert!(reduce_axes(&a, &[1, 1], 0.0, false, |x, y| x + y).is_err());
    }

    #[test]
    fn reduce_axes_out_of_bounds_errors() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert!(reduce_axes(&a, &[5], 0.0, false, |x, y| x + y).is_err());
        assert!(reduce_axes(&a, &[0, 5], 0.0, false, |x, y| x + y).is_err());
    }

    // A single multi-axis pass must agree with two sequential single-axis reductions.
    // Axis 2 is reduced first so that axis 0 keeps its index for the second step.
    #[test]
    fn reduce_axes_chained_via_single_pass_matches_sequential() {
        use ferray_core::dimension::Ix3;
        let data: Vec<f64> = (0..60).map(f64::from).collect();
        let a = Array::<f64, Ix3>::from_vec(Ix3::new([3, 4, 5]), data).unwrap();
        let single_pass = reduce_axes(&a, &[0, 2], 0.0, false, |acc, x| acc + x).unwrap();
        let step1 = reduce_axis(&a, 2, 0.0, |acc, x| acc + x).unwrap();
        let step2 = reduce_axis(&step1, 0, 0.0, |acc, x| acc + x).unwrap();
        assert_eq!(single_pass.shape(), step2.shape());
        for (a, b) in single_pass
            .as_slice()
            .unwrap()
            .iter()
            .zip(step2.as_slice().unwrap().iter())
        {
            assert!((a - b).abs() < 1e-10, "{a} vs {b}");
        }
    }

    // ---- reduce_all ----

    #[test]
    fn reduce_all_sums_full_array() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let s = reduce_all(&a, 0.0, |acc, x| acc + x);
        assert!((s - 21.0).abs() < 1e-12);
    }

    #[test]
    fn reduce_all_max_returns_global_max() {
        let a = arr2(2, 3, &[3.0, 1.0, 4.0, 1.0, 5.0, 9.0]);
        let m = reduce_all(&a, f64::NEG_INFINITY, f64::max);
        assert!((m - 9.0).abs() < 1e-12);
    }

    // With no elements, the identity is returned untouched.
    #[test]
    fn reduce_all_empty_array_returns_identity() {
        use ferray_core::dimension::Ix1;
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        let s = reduce_all(&a, 0.0, |acc, x| acc + x);
        assert_eq!(s, 0.0);
    }

    // ---- accumulate_axis ----

    #[test]
    fn accumulate_axis_add_1d() {
        let a = arr1([1.0, 2.0, 3.0, 4.0]);
        let r = accumulate_axis(&a, 0, |acc, x| acc + x).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn accumulate_axis_multiply_1d() {
        let a = arr1([1.0, 2.0, 3.0, 4.0]);
        let r = accumulate_axis(&a, 0, |acc, x| acc * x).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[1.0, 2.0, 6.0, 24.0]);
    }

    // Non-commutative op: checks the accumulator is the left operand (10-3=7, 7-2=5, 5-1=4).
    #[test]
    fn accumulate_axis_subtract_running_diff() {
        let a = arr1([10.0, 3.0, 2.0, 1.0]);
        let r = accumulate_axis(&a, 0, |acc, x| acc - x).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[10.0, 7.0, 5.0, 4.0]);
    }

    // Accumulating along axis 1 restarts the running sum on each row.
    #[test]
    fn accumulate_axis_2d_rows() {
        let a = arr2(2, 3, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let r = accumulate_axis(&a, 1, |acc, x| acc + x).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(r.as_slice().unwrap(), &[1.0, 3.0, 6.0, 4.0, 9.0, 15.0]);
    }

    #[test]
    fn accumulate_axis_out_of_bounds() {
        let a = arr1([1.0, 2.0, 3.0]);
        assert!(accumulate_axis(&a, 1, |x, y| x + y).is_err());
    }

    // ---- outer ----

    #[test]
    fn outer_multiply() {
        let a = arr1([1.0, 2.0, 3.0]);
        let b = arr1([10.0, 20.0]);
        let r = outer(&a, &b, |x, y| x * y).unwrap();
        assert_eq!(r.shape(), &[3, 2]);
        assert_eq!(r.as_slice().unwrap(), &[10.0, 20.0, 20.0, 40.0, 30.0, 60.0]);
    }

    #[test]
    fn outer_add() {
        let a = arr1([1.0, 2.0]);
        let b = arr1([10.0, 20.0, 30.0]);
        let r = outer(&a, &b, |x, y| x + y).unwrap();
        assert_eq!(r.shape(), &[2, 3]);
        assert_eq!(r.as_slice().unwrap(), &[11.0, 21.0, 31.0, 12.0, 22.0, 32.0]);
    }

    // Non-commutative op: element [i][j] is a[i].powf(b[j]).
    #[test]
    fn outer_power() {
        let a = arr1([2.0, 3.0]);
        let b = arr1([2.0, 3.0]);
        let r = outer(&a, &b, f64::powf).unwrap();
        assert_eq!(r.shape(), &[2, 2]);
        assert_eq!(r.as_slice().unwrap(), &[4.0, 8.0, 9.0, 27.0]);
    }

    // ---- at ----

    // Index 0 appears twice; both updates must land (0 + 1 + 2 = 3).
    #[test]
    fn at_add_unbuffered_duplicates() {
        let mut a = arr1([0.0, 0.0, 0.0]);
        at(&mut a, &[0, 0, 1, 2], &[1.0, 2.0, 5.0, 10.0], |acc, x| {
            acc + x
        })
        .unwrap();
        assert_eq!(a.as_slice().unwrap(), &[3.0, 5.0, 10.0]);
    }

    #[test]
    fn at_multiply() {
        let mut a = arr1([1.0, 1.0, 1.0, 1.0]);
        at(&mut a, &[1, 2, 2], &[5.0, 3.0, 4.0], |acc, x| acc * x).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[1.0, 5.0, 12.0, 1.0]);
    }

    #[test]
    fn at_length_mismatch_errors() {
        let mut a = arr1([0.0; 4]);
        assert!(at(&mut a, &[0, 1], &[1.0], |x, y| x + y).is_err());
    }

    #[test]
    fn at_index_out_of_bounds_errors() {
        let mut a = arr1([0.0; 3]);
        assert!(at(&mut a, &[5], &[1.0], |x, y| x + y).is_err());
    }
}