use crate::array::owned::Array;
use crate::array::view::ArrayView;
use crate::dimension::broadcast::broadcast_shapes;
use crate::dimension::{Dimension, IxDyn};
use crate::dtype::Element;
use crate::error::{FerrayError, FerrayResult};
/// Namespace for element-wise, broadcasting map operations over [`Array`]s.
///
/// `NdIter` carries no state; every operation is an associated function.
pub struct NdIter;

impl NdIter {
    /// Broadcasts `a` and `b` to a common shape and applies `f` to each
    /// corresponding pair of elements, returning the results as a new
    /// dynamic-rank array of that broadcast shape.
    ///
    /// This is the same-element-type case of [`NdIter::binary_map_mixed`]
    /// and simply delegates to it, so both share one implementation.
    ///
    /// # Errors
    ///
    /// Returns an error if the shapes of `a` and `b` cannot be broadcast
    /// together, and propagates any error from broadcasting the operands or
    /// from building the output array.
    pub fn binary_map<T, U, D1, D2, F>(
        a: &Array<T, D1>,
        b: &Array<T, D2>,
        f: F,
    ) -> FerrayResult<Array<U, IxDyn>>
    where
        T: Element + Copy,
        U: Element,
        D1: Dimension,
        D2: Dimension,
        F: Fn(T, T) -> U,
    {
        Self::binary_map_mixed(a, b, f)
    }

    /// Like [`NdIter::binary_map`], but the operands may have different
    /// element types `A` and `B`.
    ///
    /// When both broadcast views expose a plain slice (`as_slice` returns
    /// `Some`), the slices are zipped directly. Otherwise the views' element
    /// iterators are walked in lock-step.
    ///
    /// # Errors
    ///
    /// Returns an error if the shapes of `a` and `b` cannot be broadcast
    /// together, and propagates any error from broadcasting the operands or
    /// from building the output array.
    pub fn binary_map_mixed<A, B, U, D1, D2, F>(
        a: &Array<A, D1>,
        b: &Array<B, D2>,
        f: F,
    ) -> FerrayResult<Array<U, IxDyn>>
    where
        A: Element + Copy,
        B: Element + Copy,
        U: Element,
        D1: Dimension,
        D2: Dimension,
        F: Fn(A, B) -> U,
    {
        let shape = broadcast_shapes(a.shape(), b.shape())?;
        let a_view = a.broadcast_to(&shape)?;
        let b_view = b.broadcast_to(&shape)?;
        let total: usize = shape.iter().product();
        let mut data: Vec<U> = Vec::with_capacity(total);
        // Fast path: contiguous operands can be zipped as plain slices,
        // avoiding the per-element stride bookkeeping of the view iterators.
        if let (Some(a_slice), Some(b_slice)) = (a_view.as_slice(), b_view.as_slice()) {
            data.extend(a_slice.iter().zip(b_slice).map(|(&ai, &bi)| f(ai, bi)));
        } else {
            data.extend(
                a_view
                    .iter()
                    .zip(b_view.iter())
                    .map(|(&ai, &bi)| f(ai, bi)),
            );
        }
        Array::from_vec(IxDyn::from(&shape[..]), data)
    }

    /// Broadcasts `a` and `b` to a common shape and writes `f(a_i, b_i)`
    /// into the pre-allocated array `out`, element by element.
    ///
    /// # Errors
    ///
    /// Returns a shape-mismatch error, leaving `out` untouched, when the shape
    /// of `out` differs from the broadcast shape of `a` and `b`. Also fails if
    /// the input shapes cannot be broadcast together.
    pub fn binary_map_into<T, D1, D2>(
        a: &Array<T, D1>,
        b: &Array<T, D2>,
        out: &mut Array<T, IxDyn>,
        f: impl Fn(T, T) -> T,
    ) -> FerrayResult<()>
    where
        T: Element + Copy,
        D1: Dimension,
        D2: Dimension,
    {
        let shape = broadcast_shapes(a.shape(), b.shape())?;
        // Validate before touching `out` so a mismatch leaves it unmodified.
        if out.shape() != &shape[..] {
            return Err(FerrayError::shape_mismatch(format!(
                "output shape {:?} does not match broadcast shape {:?}",
                out.shape(),
                shape
            )));
        }
        let a_view = a.broadcast_to(&shape)?;
        let b_view = b.broadcast_to(&shape)?;
        if let Some(out_slice) = out.as_slice_mut() {
            for ((o, &ai), &bi) in out_slice.iter_mut().zip(a_view.iter()).zip(b_view.iter()) {
                *o = f(ai, bi);
            }
        } else {
            for ((&ai, &bi), o) in a_view.iter().zip(b_view.iter()).zip(out.iter_mut()) {
                *o = f(ai, bi);
            }
        }
        Ok(())
    }

    /// Applies `f` to every element of `a`, returning a new dynamic-rank
    /// array with the same shape as `a`.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the output array.
    pub fn unary_map<T, U, D, F>(a: &Array<T, D>, f: F) -> FerrayResult<Array<U, IxDyn>>
    where
        T: Element + Copy,
        U: Element,
        D: Dimension,
        F: Fn(T) -> U,
    {
        let shape = a.shape().to_vec();
        let total: usize = shape.iter().product();
        let mut data = Vec::with_capacity(total);
        // Contiguous input can be read as a plain slice; otherwise fall back
        // to the array's element iterator.
        if let Some(slice) = a.as_slice() {
            for &x in slice {
                data.push(f(x));
            }
        } else {
            for &x in a.iter() {
                data.push(f(x));
            }
        }
        Array::from_vec(IxDyn::from(&shape[..]), data)
    }

    /// Writes `f(a_i)` into the pre-allocated array `out` for every element
    /// of `a`. No broadcasting is performed.
    ///
    /// # Errors
    ///
    /// Returns a shape-mismatch error, leaving `out` untouched, unless `a` and
    /// `out` have exactly the same shape.
    pub fn unary_map_into<T, D>(
        a: &Array<T, D>,
        out: &mut Array<T, IxDyn>,
        f: impl Fn(T) -> T,
    ) -> FerrayResult<()>
    where
        T: Element + Copy,
        D: Dimension,
    {
        if a.shape() != out.shape() {
            return Err(FerrayError::shape_mismatch(format!(
                "input shape {:?} does not match output shape {:?}",
                a.shape(),
                out.shape()
            )));
        }
        if let (Some(in_slice), Some(out_slice)) = (a.as_slice(), out.as_slice_mut()) {
            for (o, &x) in out_slice.iter_mut().zip(in_slice.iter()) {
                *o = f(x);
            }
        } else {
            for (o, &x) in out.iter_mut().zip(a.iter()) {
                *o = f(x);
            }
        }
        Ok(())
    }

    /// Returns the shape obtained by broadcasting `a_shape` against `b_shape`.
    ///
    /// # Errors
    ///
    /// Returns an error if the two shapes are not broadcast-compatible.
    pub fn broadcast_shape(a_shape: &[usize], b_shape: &[usize]) -> FerrayResult<Vec<usize>> {
        broadcast_shapes(a_shape, b_shape)
    }

    /// Creates a [`BinaryBroadcastIter`] that yields `(a_i, b_i)` pairs of
    /// `a` and `b` broadcast to a common shape.
    ///
    /// Both broadcast operands are copied eagerly into owned buffers, so this
    /// allocates two vectors of the broadcast size up front.
    ///
    /// # Errors
    ///
    /// Returns an error if the shapes of `a` and `b` cannot be broadcast
    /// together.
    pub fn binary_iter<'a, T, D1, D2>(
        a: &'a Array<T, D1>,
        b: &'a Array<T, D2>,
    ) -> FerrayResult<BinaryBroadcastIter<'a, T>>
    where
        T: Element + Copy,
        D1: Dimension,
        D2: Dimension,
    {
        let shape = broadcast_shapes(a.shape(), b.shape())?;
        let a_view = a.broadcast_to(&shape)?;
        let b_view = b.broadcast_to(&shape)?;
        let a_data: Vec<T> = a_view.iter().copied().collect();
        let b_data: Vec<T> = b_view.iter().copied().collect();
        Ok(BinaryBroadcastIter {
            a_view,
            b_view,
            a_data,
            b_data,
            index: 0,
        })
    }
}
/// Lock-step iterator over two operands broadcast to a common shape.
///
/// Both operands are materialised into owned buffers up front, so advancing
/// the iterator is a plain indexed read. The broadcast views are retained
/// alongside the buffers.
pub struct BinaryBroadcastIter<'a, T: Element> {
    /// Left operand, broadcast to the common shape.
    a_view: ArrayView<'a, T, IxDyn>,
    /// Right operand, broadcast to the common shape.
    b_view: ArrayView<'a, T, IxDyn>,
    /// Elements of the left operand in iteration order.
    a_data: Vec<T>,
    /// Elements of the right operand in iteration order (same length as `a_data`).
    b_data: Vec<T>,
    /// Position of the next pair to yield.
    index: usize,
}

impl<T: Element + Copy> Iterator for BinaryBroadcastIter<'_, T> {
    type Item = (T, T);

    /// Yields the next `(a, b)` pair, or `None` once every pair was produced.
    /// The cursor does not move after exhaustion.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        let pos = self.index;
        let left = *self.a_data.get(pos)?;
        self.index = pos + 1;
        Some((left, self.b_data[pos]))
    }

    /// Reports the exact number of pairs still to be yielded.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left_over = self.a_data.len() - self.index;
        (left_over, Some(left_over))
    }
}

/// `size_hint` is exact, so `len()` is the number of pairs still to come.
impl<T: Element + Copy> ExactSizeIterator for BinaryBroadcastIter<'_, T> {}
impl<T: Element> BinaryBroadcastIter<'_, T> {
    /// Applies `f` to every *remaining* `(a, b)` pair and collects the results.
    ///
    /// Pairs already yielded through [`Iterator::next`] are skipped, so this
    /// agrees with `self.map(|(a, b)| f(a, b)).collect()` and with `len()`.
    pub fn map_collect<U, F>(self, f: F) -> Vec<U>
    where
        T: Copy,
        F: Fn(T, T) -> U,
    {
        let (a_rest, b_rest) = self.remaining();
        a_rest.iter().zip(b_rest).map(|(&a, &b)| f(a, b)).collect()
    }

    /// Calls `f(a, b)` for every *remaining* pair, consuming the iterator.
    ///
    /// This inherent method takes a two-argument closure and shadows
    /// [`Iterator::for_each`] on this type. Pairs already yielded through
    /// [`Iterator::next`] are skipped.
    pub fn for_each<F>(self, mut f: F)
    where
        T: Copy,
        F: FnMut(T, T),
    {
        let (a_rest, b_rest) = self.remaining();
        for (&a, &b) in a_rest.iter().zip(b_rest) {
            f(a, b);
        }
    }

    /// Shape of the broadcast operands. Unaffected by how many pairs were
    /// consumed.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        self.a_view.shape()
    }

    /// Total number of elements in the broadcast shape. Unaffected by how
    /// many pairs were consumed; use `len()` for the remaining count.
    #[must_use]
    pub fn size(&self) -> usize {
        self.a_view.size()
    }

    /// Unconsumed tails of both operand buffers.
    fn remaining(&self) -> (&[T], &[T]) {
        // `next` only advances `index` while it is below `a_data.len()`, and
        // both buffers have the same length, so the slices are in bounds.
        let start = self.index;
        (&self.a_data[start..], &self.b_data[start..])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    #[test]
    fn binary_map_same_shape() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let c = NdIter::binary_map(&a, &b, |x, y| x + y).unwrap();
        assert_eq!(c.shape(), &[3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn binary_map_broadcast_1d_to_2d() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let c = NdIter::binary_map(&a, &b, |x, y| x + y).unwrap();
        assert_eq!(c.shape(), &[3, 4]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                11.0, 21.0, 31.0, 41.0, 12.0, 22.0, 32.0, 42.0, 13.0, 23.0, 33.0, 43.0
            ]
        );
    }

    #[test]
    fn binary_map_broadcast_scalar() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([1]), vec![100.0]).unwrap();
        let c = NdIter::binary_map(&a, &b, |x, y| x * y).unwrap();
        assert_eq!(c.shape(), &[3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![100.0, 200.0, 300.0]);
    }

    #[test]
    fn binary_map_incompatible_shapes() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let result = NdIter::binary_map(&a, &b, |x, y| x + y);
        assert!(result.is_err());
    }

    #[test]
    fn binary_map_to_bool() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 5.0, 3.0, 7.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![2.0, 3.0, 3.0, 6.0]).unwrap();
        let c = NdIter::binary_map(&a, &b, |x, y| x > y).unwrap();
        assert_eq!(c.shape(), &[4]);
        let data: Vec<bool> = c.iter().copied().collect();
        assert_eq!(data, vec![false, true, false, true]);
    }

    #[test]
    fn binary_map_into_preallocated() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let mut out = Array::<f64, IxDyn>::zeros(IxDyn::new(&[3])).unwrap();
        NdIter::binary_map_into(&a, &b, &mut out, |x, y| x + y).unwrap();
        let data: Vec<f64> = out.iter().copied().collect();
        assert_eq!(data, vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn binary_map_into_broadcast() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let mut out = Array::<f64, IxDyn>::zeros(IxDyn::new(&[3, 4])).unwrap();
        NdIter::binary_map_into(&a, &b, &mut out, |x, y| x + y).unwrap();
        let data: Vec<f64> = out.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                11.0, 21.0, 31.0, 41.0, 12.0, 22.0, 32.0, 42.0, 13.0, 23.0, 33.0, 43.0
            ]
        );
    }

    #[test]
    fn binary_map_into_wrong_shape_error() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let mut out = Array::<f64, IxDyn>::zeros(IxDyn::new(&[5])).unwrap();
        let result = NdIter::binary_map_into(&a, &b, &mut out, |x, y| x + y);
        assert!(result.is_err());
    }

    #[test]
    fn unary_map_basic() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 4.0, 9.0, 16.0]).unwrap();
        let c = NdIter::unary_map(&a, f64::sqrt).unwrap();
        assert_eq!(c.shape(), &[4]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn unary_map_into_preallocated() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 4.0, 9.0]).unwrap();
        let mut out = Array::<f64, IxDyn>::zeros(IxDyn::new(&[3])).unwrap();
        NdIter::unary_map_into(&a, &mut out, |x| x * 2.0).unwrap();
        let data: Vec<f64> = out.iter().copied().collect();
        assert_eq!(data, vec![2.0, 8.0, 18.0]);
    }

    #[test]
    fn unary_map_into_wrong_shape_error() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 4.0, 9.0]).unwrap();
        let mut out = Array::<f64, IxDyn>::zeros(IxDyn::new(&[4])).unwrap();
        let result = NdIter::unary_map_into(&a, &mut out, |x| x * 2.0);
        assert!(result.is_err());
        // The shape check happens before any write, so `out` stays zeroed.
        assert!(out.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn broadcast_shape_helper() {
        assert_eq!(NdIter::broadcast_shape(&[3, 1], &[4]).unwrap(), vec![3, 4]);
        assert_eq!(
            NdIter::broadcast_shape(&[2, 1, 4], &[3, 1]).unwrap(),
            vec![2, 3, 4]
        );
        assert!(NdIter::broadcast_shape(&[3], &[4]).is_err());
    }

    #[test]
    fn binary_iter_shape() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![1.0, 2.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        assert_eq!(iter.shape(), &[2, 3]);
    }

    #[test]
    fn binary_iter_map_collect() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        let result: Vec<f64> = iter.map_collect(|x, y| x + y);
        assert_eq!(result, vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn binary_map_3d_broadcast() {
        let a =
            Array::<i32, Ix3>::from_vec(Ix3::new([2, 1, 4]), vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap();
        let b = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![10, 20, 30]).unwrap();
        let c = NdIter::binary_map(&a, &b, |x, y| x + y).unwrap();
        assert_eq!(c.shape(), &[2, 3, 4]);
        // c[i, j, k] = a[i, 0, k] + b[j, 0]; check every element, not just the first.
        let data: Vec<i32> = c.iter().copied().collect();
        assert_eq!(
            data,
            vec![
                11, 12, 13, 14, 21, 22, 23, 24, 31, 32, 33, 34, //
                15, 16, 17, 18, 25, 26, 27, 28, 35, 36, 37, 38,
            ]
        );
    }

    #[test]
    fn binary_iter_next() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let mut iter = NdIter::binary_iter(&a, &b).unwrap();
        assert_eq!(iter.next(), Some((1.0, 10.0)));
        assert_eq!(iter.next(), Some((2.0, 20.0)));
        assert_eq!(iter.next(), Some((3.0, 30.0)));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn binary_iter_for_loop() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let b = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        let sums: Vec<i32> = iter.map(|(x, y)| x + y).collect();
        assert_eq!(sums, vec![11, 22, 33]);
    }

    #[test]
    fn binary_iter_broadcast_with_next() {
        let a = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![1.0, 2.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        assert_eq!(iter.len(), 6);
        let pairs: Vec<(f64, f64)> = iter.collect();
        assert_eq!(
            pairs,
            vec![
                (1.0, 10.0),
                (1.0, 20.0),
                (1.0, 30.0),
                (2.0, 10.0),
                (2.0, 20.0),
                (2.0, 30.0),
            ]
        );
    }

    #[test]
    fn binary_iter_exact_size() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0; 5]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![2.0; 5]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        assert_eq!(iter.len(), 5);
    }

    #[test]
    fn binary_iter_for_each_method() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![10.0, 20.0, 30.0]).unwrap();
        let iter = NdIter::binary_iter(&a, &b).unwrap();
        let mut sum = 0.0;
        iter.for_each(|x, y| sum += x + y);
        assert!((sum - 66.0).abs() < 1e-10);
    }

    #[test]
    fn binary_map_mixed_i32_f64() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![0.5, 1.5, 2.5]).unwrap();
        let c = NdIter::binary_map_mixed(&a, &b, |x, y| x as f64 + y).unwrap();
        assert_eq!(c.shape(), &[3]);
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data, vec![1.5, 3.5, 5.5]);
    }

    #[test]
    fn binary_map_mixed_broadcast() {
        let a = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![1, 2, 3]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![0.1, 0.2, 0.3, 0.4]).unwrap();
        let c = NdIter::binary_map_mixed(&a, &b, |x, y| x as f64 + y).unwrap();
        assert_eq!(c.shape(), &[3, 4]);
        let expected = [
            1.1, 1.2, 1.3, 1.4, 2.1, 2.2, 2.3, 2.4, 3.1, 3.2, 3.3, 3.4,
        ];
        let data: Vec<f64> = c.iter().copied().collect();
        assert_eq!(data.len(), expected.len());
        for (got, want) in data.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-10, "got {got}, want {want}");
        }
    }

    #[test]
    fn binary_map_mixed_incompatible_shapes() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let result = NdIter::binary_map_mixed(&a, &b, |x, y| x as f64 + y);
        assert!(result.is_err());
    }

    #[test]
    fn binary_map_mixed_to_bool() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 5, 3, 7]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![2.0, 3.0, 3.0, 6.0]).unwrap();
        let c = NdIter::binary_map_mixed(&a, &b, |x, y| (x as f64) > y).unwrap();
        let data: Vec<bool> = c.iter().copied().collect();
        assert_eq!(data, vec![false, true, false, true]);
    }
}