use crate::dimension::{Axis, Dimension, IxDyn};
use crate::dtype::Element;
use crate::error::{FerrayError, FerrayResult};
use super::owned::Array;
use super::view::ArrayView;
impl<T: Element, D: Dimension> Array<T, D> {
pub fn mapv(&self, f: impl Fn(T) -> T) -> Self {
let inner = self.inner.mapv(&f);
Self::from_ndarray(inner)
}
pub fn mapv_inplace(&mut self, f: impl Fn(T) -> T) {
self.inner.mapv_inplace(&f);
}
pub fn zip_mut_with(
&mut self,
other: &Array<T, D>,
f: impl Fn(&mut T, &T),
) -> FerrayResult<()> {
if self.shape() != other.shape() {
return Err(FerrayError::shape_mismatch(format!(
"cannot zip arrays with shapes {:?} and {:?}",
self.shape(),
other.shape(),
)));
}
self.inner.zip_mut_with(&other.inner, |a, b| f(a, b));
Ok(())
}
pub fn fold_axis(
&self,
axis: Axis,
init: T,
fold: impl FnMut(&T, &T) -> T,
) -> FerrayResult<Array<T, IxDyn>>
where
D::NdarrayDim: ndarray::RemoveAxis,
{
let ndim = self.ndim();
if axis.index() >= ndim {
return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
}
let nd_axis = ndarray::Axis(axis.index());
let mut fold = fold;
let result = self.inner.fold_axis(nd_axis, init, |acc, x| fold(acc, x));
let dyn_result = result.into_dyn();
Ok(Array::from_ndarray(dyn_result))
}
pub fn map_to<U: Element>(&self, f: impl Fn(T) -> U) -> Array<U, D> {
let inner = self.inner.mapv(&f);
Array::from_ndarray(inner)
}
}
impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
pub fn mapv(&self, f: impl Fn(T) -> T) -> Array<T, D> {
let inner = self.inner.mapv(&f);
Array::from_ndarray(inner)
}
pub fn fold_axis(
&self,
axis: Axis,
init: T,
fold: impl FnMut(&T, &T) -> T,
) -> FerrayResult<Array<T, IxDyn>>
where
D::NdarrayDim: ndarray::RemoveAxis,
{
let ndim = self.ndim();
if axis.index() >= ndim {
return Err(FerrayError::axis_out_of_bounds(axis.index(), ndim));
}
let nd_axis = ndarray::Axis(axis.index());
let mut fold = fold;
let result = self.inner.fold_axis(nd_axis, init, |acc, x| fold(acc, x));
let dyn_result = result.into_dyn();
Ok(Array::from_ndarray(dyn_result))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dimension::{Ix1, Ix2};
#[test]
fn mapv_double() {
let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
let doubled = arr.mapv(|x| x * 2.0);
assert_eq!(doubled.as_slice().unwrap(), &[2.0, 4.0, 6.0, 8.0]);
assert_eq!(arr.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn mapv_inplace_negate() {
let mut arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, -2.0, 3.0]).unwrap();
arr.mapv_inplace(|x| -x);
assert_eq!(arr.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
}
#[test]
fn zip_mut_with_add() {
let mut a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
a.zip_mut_with(&b, |x, y| *x += y).unwrap();
assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
}
#[test]
fn zip_mut_with_shape_mismatch() {
let mut a = Array::<f64, Ix1>::zeros(Ix1::new([3])).unwrap();
let b = Array::<f64, Ix1>::zeros(Ix1::new([5])).unwrap();
assert!(a.zip_mut_with(&b, |_, _| {}).is_err());
}
#[test]
fn fold_axis_sum_rows() {
let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
.unwrap();
let sums = arr.fold_axis(Axis(1), 0.0, |acc, &x| *acc + x).unwrap();
assert_eq!(sums.shape(), &[2]);
let data: Vec<f64> = sums.iter().copied().collect();
assert_eq!(data, vec![6.0, 15.0]);
}
#[test]
fn fold_axis_sum_cols() {
let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
.unwrap();
let sums = arr.fold_axis(Axis(0), 0.0, |acc, &x| *acc + x).unwrap();
assert_eq!(sums.shape(), &[3]);
let data: Vec<f64> = sums.iter().copied().collect();
assert_eq!(data, vec![5.0, 7.0, 9.0]);
}
#[test]
fn fold_axis_out_of_bounds() {
let arr = Array::<f64, Ix2>::zeros(Ix2::new([2, 3])).unwrap();
assert!(arr.fold_axis(Axis(2), 0.0, |a, _| *a).is_err());
}
#[test]
fn map_to_different_type() {
let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.5, 2.7, 3.1]).unwrap();
let ints: Array<i32, Ix1> = arr.map_to(|x| x as i32);
assert_eq!(ints.as_slice().unwrap(), &[1, 2, 3]);
}
#[test]
fn view_mapv() {
let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
let v = arr.view();
let doubled = v.mapv(|x| x * 2.0);
assert_eq!(doubled.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
}
}