use num_traits::Float;
use crate::array::owned::Array;
use crate::array::view::ArrayView;
use crate::dimension::{Axis, Dimension, IxDyn};
use crate::dtype::Element;
use crate::error::FerrayResult;
/// One step of a NaN-propagating min/max fold.
///
/// A value is treated as NaN when it does not compare equal-or-ordered with
/// itself (`partial_cmp` with itself yields `None`). The running accumulator
/// `acc` wins if it is NaN; otherwise a NaN `x` wins. For comparable values,
/// `x` replaces `acc` only when it is strictly smaller (`take_min == true`) or
/// strictly larger (`take_min == false`), so ties keep the earlier value.
#[inline]
fn reduce_step<T: PartialOrd + Copy>(acc: T, x: T, take_min: bool) -> T {
    let unordered = |v: &T| v.partial_cmp(v).is_none();
    if unordered(&acc) {
        acc
    } else if unordered(&x) {
        x
    } else {
        let wanted = if take_min {
            std::cmp::Ordering::Less
        } else {
            std::cmp::Ordering::Greater
        };
        if x.partial_cmp(&acc) == Some(wanted) {
            x
        } else {
            acc
        }
    }
}
impl<T, D> Array<T, D>
where
    T: Element + Copy,
    D: Dimension,
{
    /// Sum of every element, accumulated left-to-right in iteration order.
    ///
    /// An empty array yields `T::zero()`.
    pub fn sum(&self) -> T
    where
        T: std::ops::Add<Output = T>,
    {
        self.iter().fold(T::zero(), |total, &item| total + item)
    }

    /// Sum along `axis`, returning a dynamically-dimensioned array.
    ///
    /// Delegates to `fold_axis` starting from `T::zero()`; any error (presumably an
    /// out-of-bounds axis — confirm against `fold_axis`) is propagated unchanged.
    pub fn sum_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
    where
        T: std::ops::Add<Output = T>,
        D::NdarrayDim: ndarray::RemoveAxis,
    {
        let seed = T::zero();
        self.fold_axis(axis, seed, |running, &item| *running + item)
    }

    /// Product of every element, accumulated left-to-right in iteration order.
    ///
    /// An empty array yields `T::one()`.
    pub fn prod(&self) -> T
    where
        T: std::ops::Mul<Output = T>,
    {
        self.iter().fold(T::one(), |total, &item| total * item)
    }

    /// Product along `axis`, returning a dynamically-dimensioned array.
    ///
    /// Delegates to `fold_axis` starting from `T::one()`; errors from `fold_axis`
    /// are propagated unchanged.
    pub fn prod_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
    where
        T: std::ops::Mul<Output = T>,
        D::NdarrayDim: ndarray::RemoveAxis,
    {
        let seed = T::one();
        self.fold_axis(axis, seed, |running, &item| *running * item)
    }
}
impl<T, D> Array<T, D>
where
T: Element + Copy + PartialOrd,
D: Dimension,
{
pub fn min(&self) -> Option<T> {
let mut iter = self.iter().copied();
let first = iter.next()?;
Some(iter.fold(first, |acc, x| reduce_step(acc, x, true)))
}
pub fn max(&self) -> Option<T> {
let mut iter = self.iter().copied();
let first = iter.next()?;
Some(iter.fold(first, |acc, x| reduce_step(acc, x, false)))
}
pub fn min_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
where
D::NdarrayDim: ndarray::RemoveAxis,
{
let ndim = self.ndim();
if axis.index() >= ndim {
return Err(crate::error::FerrayError::axis_out_of_bounds(
axis.index(),
ndim,
));
}
if self.shape()[axis.index()] == 0 {
return Err(crate::error::FerrayError::shape_mismatch(
"cannot compute min along empty axis",
));
}
self.fold_axis_min_max(axis, true)
}
pub fn max_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
where
D::NdarrayDim: ndarray::RemoveAxis,
{
let ndim = self.ndim();
if axis.index() >= ndim {
return Err(crate::error::FerrayError::axis_out_of_bounds(
axis.index(),
ndim,
));
}
if self.shape()[axis.index()] == 0 {
return Err(crate::error::FerrayError::shape_mismatch(
"cannot compute max along empty axis",
));
}
self.fold_axis_min_max(axis, false)
}
fn fold_axis_min_max(&self, axis: Axis, take_min: bool) -> FerrayResult<Array<T, IxDyn>>
where
D::NdarrayDim: ndarray::RemoveAxis,
{
let nd_axis = ndarray::Axis(axis.index());
let lanes = self.inner.lanes(nd_axis);
let mut out: Vec<T> = Vec::with_capacity(lanes.into_iter().len());
for lane in self.inner.lanes(nd_axis) {
let mut iter = lane.iter().copied();
let first = iter.next().unwrap(); let result = iter.fold(first, |acc, x| reduce_step(acc, x, take_min));
out.push(result);
}
let mut out_shape: Vec<usize> = self.shape().to_vec();
out_shape.remove(axis.index());
Array::from_vec(IxDyn::from(&out_shape[..]), out)
}
}
impl<T, D> Array<T, D>
where
T: Element + Float,
D: Dimension,
{
pub fn mean(&self) -> Option<T> {
let n = self.size();
if n == 0 {
return None;
}
let sum: T = self
.iter()
.copied()
.fold(<T as Element>::zero(), |acc, x| acc + x);
Some(sum / T::from(n).unwrap())
}
pub fn mean_axis(&self, axis: Axis) -> FerrayResult<Array<T, IxDyn>>
where
D::NdarrayDim: ndarray::RemoveAxis,
{
let ndim = self.ndim();
if axis.index() >= ndim {
return Err(crate::error::FerrayError::axis_out_of_bounds(
axis.index(),
ndim,
));
}
let n = self.shape()[axis.index()];
if n == 0 {
return Err(crate::error::FerrayError::shape_mismatch(
"cannot compute mean along empty axis",
));
}
let sums = self.sum_axis(axis)?;
let n_t = T::from(n).unwrap();
Ok(sums.mapv(|x| x / n_t))
}
pub fn var(&self, ddof: usize) -> Option<T> {
let n = self.size();
if n == 0 || ddof >= n {
return None;
}
let mean = self.mean()?;
let sum_sq: T = self.iter().copied().fold(<T as Element>::zero(), |acc, x| {
acc + (x - mean) * (x - mean)
});
Some(sum_sq / T::from(n - ddof).unwrap())
}
pub fn std(&self, ddof: usize) -> Option<T> {
self.var(ddof).map(num_traits::Float::sqrt)
}
}
impl<D> Array<bool, D>
where
    D: Dimension,
{
    /// `true` if at least one element is `true`; `false` for an empty array.
    pub fn any(&self) -> bool {
        self.iter().any(|flag| *flag)
    }

    /// `true` if every element is `true`; vacuously `true` for an empty array.
    pub fn all(&self) -> bool {
        // "all true" == "no element is false"; stops at the first `false`.
        !self.iter().any(|flag| !*flag)
    }
}
impl<T, D> ArrayView<'_, T, D>
where
    T: Element + Copy,
    D: Dimension,
{
    /// Sum of every element in the view; `T::zero()` when the view is empty.
    pub fn sum(&self) -> T
    where
        T: std::ops::Add<Output = T>,
    {
        self.iter().fold(T::zero(), |total, &item| total + item)
    }

    /// Product of every element in the view; `T::one()` when the view is empty.
    pub fn prod(&self) -> T
    where
        T: std::ops::Mul<Output = T>,
    {
        self.iter().fold(T::one(), |total, &item| total * item)
    }
}
impl<T, D> ArrayView<'_, T, D>
where
    T: Element + Copy + PartialOrd,
    D: Dimension,
{
    /// Smallest element of the view, or `None` if it is empty. NaN propagates.
    pub fn min(&self) -> Option<T> {
        self.iter()
            .copied()
            .reduce(|best, candidate| reduce_step(best, candidate, true))
    }

    /// Largest element of the view, or `None` if it is empty. NaN propagates.
    pub fn max(&self) -> Option<T> {
        self.iter()
            .copied()
            .reduce(|best, candidate| reduce_step(best, candidate, false))
    }
}
impl<T, D> ArrayView<'_, T, D>
where
    T: Element + Float,
    D: Dimension,
{
    /// Arithmetic mean of the view's elements, or `None` if the view is empty.
    pub fn mean(&self) -> Option<T> {
        match self.size() {
            0 => None,
            count => {
                let total = self
                    .iter()
                    .fold(<T as Element>::zero(), |acc, &item| acc + item);
                Some(total / T::from(count).unwrap())
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Builds a 1-D `f64` array holding `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Builds a `rows x cols` `f64` array from `data` (in the order `from_vec` expects).
    fn arr2(rows: usize, cols: usize, data: Vec<f64>) -> Array<f64, Ix2> {
        Array::from_vec(Ix2::new([rows, cols]), data).unwrap()
    }

    #[test]
    fn sum_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.sum(), 10.0);
    }

    #[test]
    fn sum_empty_returns_zero() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.sum(), 0.0);
    }

    #[test]
    fn sum_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let s0 = a.sum_axis(Axis(0)).unwrap();
        assert_eq!(s0.shape(), &[3]);
        assert_eq!(s0.iter().copied().collect::<Vec<_>>(), vec![5.0, 7.0, 9.0]);
        let s1 = a.sum_axis(Axis(1)).unwrap();
        assert_eq!(s1.shape(), &[2]);
        assert_eq!(s1.iter().copied().collect::<Vec<_>>(), vec![6.0, 15.0]);
    }

    #[test]
    fn prod_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.prod(), 24.0);
    }

    #[test]
    fn prod_empty_returns_one() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.prod(), 1.0);
    }

    #[test]
    fn prod_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let p0 = a.prod_axis(Axis(0)).unwrap();
        assert_eq!(
            p0.iter().copied().collect::<Vec<_>>(),
            vec![4.0, 10.0, 18.0]
        );
        let p1 = a.prod_axis(Axis(1)).unwrap();
        assert_eq!(p1.iter().copied().collect::<Vec<_>>(), vec![6.0, 120.0]);
    }

    #[test]
    fn min_max_1d() {
        let a = arr1(vec![3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0]);
        assert_eq!(a.min(), Some(1.0));
        assert_eq!(a.max(), Some(9.0));
    }

    #[test]
    fn min_max_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.min(), None);
        assert_eq!(a.max(), None);
    }

    #[test]
    fn min_max_int() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![3, -1, 4, -5, 2]).unwrap();
        assert_eq!(a.min(), Some(-5));
        assert_eq!(a.max(), Some(4));
    }

    #[test]
    fn min_max_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0]);
        let mn0 = a.min_axis(Axis(0)).unwrap();
        assert_eq!(mn0.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0, 3.0]);
        let mx0 = a.max_axis(Axis(0)).unwrap();
        assert_eq!(mx0.iter().copied().collect::<Vec<_>>(), vec![4.0, 5.0, 6.0]);
        let mn1 = a.min_axis(Axis(1)).unwrap();
        assert_eq!(mn1.iter().copied().collect::<Vec<_>>(), vec![1.0, 2.0]);
        let mx1 = a.max_axis(Axis(1)).unwrap();
        assert_eq!(mx1.iter().copied().collect::<Vec<_>>(), vec![5.0, 6.0]);
    }

    #[test]
    fn min_max_axis_out_of_bounds_is_err() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert!(a.min_axis(Axis(2)).is_err());
        assert!(a.max_axis(Axis(2)).is_err());
    }

    #[test]
    fn min_max_axis_empty_axis_is_err() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([0, 3]), vec![]).unwrap();
        assert!(a.min_axis(Axis(0)).is_err());
        assert!(a.max_axis(Axis(0)).is_err());
        // Reducing along a non-empty axis of an array with no lanes is fine.
        let r = a.min_axis(Axis(1)).unwrap();
        assert_eq!(r.shape(), &[0]);
    }

    #[test]
    fn nan_propagates_in_min_max_axis() {
        // Laid out as [[1, NaN], [3, 4]].
        let a = arr2(2, 2, vec![1.0, f64::NAN, 3.0, 4.0]);
        let mn0: Vec<f64> = a.min_axis(Axis(0)).unwrap().iter().copied().collect();
        assert_eq!(mn0[0], 1.0);
        assert!(mn0[1].is_nan());
        let mx1: Vec<f64> = a.max_axis(Axis(1)).unwrap().iter().copied().collect();
        assert!(mx1[0].is_nan());
        assert_eq!(mx1[1], 4.0);
    }

    #[test]
    fn mean_1d() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.mean(), Some(2.5));
    }

    #[test]
    fn mean_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.mean(), None);
    }

    #[test]
    fn mean_axis_2d() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let m0 = a.mean_axis(Axis(0)).unwrap();
        assert_eq!(m0.iter().copied().collect::<Vec<_>>(), vec![2.5, 3.5, 4.5]);
        let m1 = a.mean_axis(Axis(1)).unwrap();
        assert_eq!(m1.iter().copied().collect::<Vec<_>>(), vec![2.0, 5.0]);
    }

    #[test]
    fn mean_axis_errors() {
        let a = arr2(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert!(a.mean_axis(Axis(2)).is_err());
        let e = Array::<f64, Ix2>::from_vec(Ix2::new([0, 3]), vec![]).unwrap();
        assert!(e.mean_axis(Axis(0)).is_err());
    }

    #[test]
    fn var_population() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(a.var(0), Some(2.0));
    }

    #[test]
    fn var_sample() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(a.var(1), Some(2.5));
    }

    #[test]
    fn std_basic() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let s = a.std(0).unwrap();
        assert!((s - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn std_sample() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let s = a.std(1).unwrap();
        assert!((s - 2.5_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn var_ddof_too_large_returns_none() {
        let a = arr1(vec![1.0, 2.0]);
        assert_eq!(a.var(2), None);
        assert_eq!(a.var(5), None);
    }

    #[test]
    fn var_std_empty_returns_none() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(a.var(0), None);
        assert_eq!(a.std(0), None);
    }

    #[test]
    fn any_all_bool() {
        let true_arr = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let mixed = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        let false_arr =
            Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let empty = Array::<bool, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert!(true_arr.all());
        assert!(true_arr.any());
        assert!(!mixed.all());
        assert!(mixed.any());
        assert!(!false_arr.all());
        assert!(!false_arr.any());
        assert!(empty.all());
        assert!(!empty.any());
    }

    #[test]
    fn view_sum_min_max_mean() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        let v = a.view();
        assert_eq!(v.sum(), 10.0);
        assert_eq!(v.min(), Some(1.0));
        assert_eq!(v.max(), Some(4.0));
        assert_eq!(v.mean(), Some(2.5));
    }

    #[test]
    fn view_prod_and_empty_min() {
        let a = arr1(vec![1.0, 2.0, 3.0, 4.0]);
        assert_eq!(a.view().prod(), 24.0);
        let e = Array::<f64, Ix1>::from_vec(Ix1::new([0]), vec![]).unwrap();
        assert_eq!(e.view().min(), None);
        assert_eq!(e.view().mean(), None);
    }

    #[test]
    fn nan_propagates_in_min_max() {
        let a = arr1(vec![1.0, f64::NAN, 3.0]);
        assert!(a.min().unwrap().is_nan());
        assert!(a.max().unwrap().is_nan());
        let b = arr1(vec![f64::NAN, 1.0, 3.0]);
        assert!(b.min().unwrap().is_nan());
        assert!(b.max().unwrap().is_nan());
        let c = arr1(vec![1.0, 3.0, f64::NAN]);
        assert!(c.min().unwrap().is_nan());
        assert!(c.max().unwrap().is_nan());
    }
}