use crate::array::owned::Array;
use crate::dimension::Dimension;
use crate::dimension::IxDyn;
use crate::dimension::broadcast::{broadcast_shapes, broadcast_to};
use crate::dtype::Element;
use crate::error::{FerrayError, FerrayResult};
/// Combines two arrays of the same dimension type element by element with
/// `op`, broadcasting them against each other when their shapes differ.
///
/// `op_name` (e.g. `"+"`) is used only to label error messages.
///
/// # Errors
///
/// Returns a shape-mismatch error if the two shapes are not
/// broadcast-compatible, or if the broadcast result shape cannot be
/// represented by the dimension type `D`. Errors from `broadcast_to` and
/// `Array::from_vec` are propagated as-is.
fn elementwise_binary<T, D, F>(
    a: &Array<T, D>,
    b: &Array<T, D>,
    op: F,
    op_name: &str,
) -> FerrayResult<Array<T, D>>
where
    T: Element + Copy,
    D: Dimension,
    F: Fn(T, T) -> T,
{
    // Identical shapes need no broadcasting, and the result can reuse `a`'s
    // dimension value directly.
    if a.shape() == b.shape() {
        let values = a
            .iter()
            .copied()
            .zip(b.iter().copied())
            .map(|(lhs, rhs)| op(lhs, rhs))
            .collect::<Vec<T>>();
        return Array::from_vec(a.dim().clone(), values);
    }

    let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
        FerrayError::shape_mismatch(format!(
            "operator {}: shapes {:?} and {:?} are not broadcast-compatible",
            op_name,
            a.shape(),
            b.shape()
        ))
    })?;

    // Stretch both operands to the common shape, then combine them pairwise
    // in the views' iteration order.
    let (lhs_view, rhs_view) = (
        broadcast_to(a, &target_shape)?,
        broadcast_to(b, &target_shape)?,
    );
    let values = lhs_view
        .iter()
        .copied()
        .zip(rhs_view.iter().copied())
        .map(|(lhs, rhs)| op(lhs, rhs))
        .collect::<Vec<T>>();

    // The broadcast shape must still be expressible as the caller's `D`;
    // this is checked rather than assumed.
    let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
        FerrayError::shape_mismatch(format!(
            "operator {op_name}: cannot represent broadcast result shape {target_shape:?} as the input dimension type"
        ))
    })?;
    Array::from_vec(result_dim, values)
}
/// Combines two arrays of possibly different dimension types element by
/// element with `op`, after broadcasting both to their common shape.
///
/// The result is always dynamically ranked (`IxDyn`), so no conversion back
/// to a fixed dimension type is required. `op_name` only labels errors.
///
/// # Errors
///
/// Returns a shape-mismatch error if the shapes are not broadcast-compatible;
/// errors from `broadcast_to` and `Array::from_vec` are propagated as-is.
fn elementwise_binary_dyn<T, D1, D2, F>(
    a: &Array<T, D1>,
    b: &Array<T, D2>,
    op: F,
    op_name: &str,
) -> FerrayResult<Array<T, IxDyn>>
where
    T: Element + Copy,
    D1: Dimension,
    D2: Dimension,
    F: Fn(T, T) -> T,
{
    let common_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
        FerrayError::shape_mismatch(format!(
            "{}: shapes {:?} and {:?} are not broadcast-compatible",
            op_name,
            a.shape(),
            b.shape()
        ))
    })?;

    let lhs = broadcast_to(a, &common_shape)?;
    let rhs = broadcast_to(b, &common_shape)?;
    let values = lhs
        .iter()
        .zip(rhs.iter())
        .map(|(x, y)| op(*x, *y))
        .collect::<Vec<_>>();

    Array::from_vec(IxDyn::from(&common_shape[..]), values)
}
/// Implements a binary arithmetic operator (`Add`, `Sub`, ...) between two
/// arrays sharing element and dimension type, for every owned/borrowed
/// combination of operands.
///
/// Each generated `impl` forwards to `elementwise_binary`, so differing
/// shapes are broadcast when compatible and reported as an error otherwise.
/// The operator's output is `FerrayResult<Array<T, D>>`.
macro_rules! impl_binary_op {
    // Internal rule: emit one impl for a single (lhs, rhs) ownership pairing.
    // The bracketed token (`&` or nothing) says whether that operand must be
    // borrowed before being handed to `elementwise_binary`.
    (@pair $trait:ident, $method:ident, $op_fn:expr, $op_name:expr,
     ($($lhs:tt)+), [$($lhs_ref:tt)?], ($($rhs:tt)+), [$($rhs_ref:tt)?]) => {
        impl<T, D> std::ops::$trait<$($rhs)+> for $($lhs)+
        where
            T: Element + Copy + std::ops::$trait<Output = T>,
            D: Dimension,
        {
            type Output = FerrayResult<Array<T, D>>;

            fn $method(self, rhs: $($rhs)+) -> Self::Output {
                elementwise_binary($($lhs_ref)? self, $($rhs_ref)? rhs, $op_fn, $op_name)
            }
        }
    };
    // Public entry: all four owned/borrowed operand combinations.
    ($trait:ident, $method:ident, $op_fn:expr, $op_name:expr) => {
        impl_binary_op!(@pair $trait, $method, $op_fn, $op_name,
            (&Array<T, D>), [], (&Array<T, D>), []);
        impl_binary_op!(@pair $trait, $method, $op_fn, $op_name,
            (Array<T, D>), [&], (Array<T, D>), [&]);
        impl_binary_op!(@pair $trait, $method, $op_fn, $op_name,
            (Array<T, D>), [&], (&Array<T, D>), []);
        impl_binary_op!(@pair $trait, $method, $op_fn, $op_name,
            (&Array<T, D>), [], (Array<T, D>), [&]);
    };
}

// Array-with-array arithmetic; the string is the operator shown in errors.
impl_binary_op!(Add, add, |a, b| a + b, "+");
impl_binary_op!(Sub, sub, |a, b| a - b, "-");
impl_binary_op!(Mul, mul, |a, b| a * b, "*");
impl_binary_op!(Div, div, |a, b| a / b, "/");
impl_binary_op!(Rem, rem, |a, b| a % b, "%");
/// Implements `array <op> scalar` for an arithmetic operator, for both a
/// borrowed and an owned array on the left-hand side.
///
/// The scalar is combined with every element and the result keeps the
/// array's dimension; the output type is `FerrayResult<Array<T, D>>`.
macro_rules! impl_scalar_op {
    ($trait:ident, $method:ident, $op_fn:expr) => {
        impl<T, D> std::ops::$trait<T> for &Array<T, D>
        where
            T: Element + Copy + std::ops::$trait<Output = T>,
            D: Dimension,
        {
            type Output = FerrayResult<Array<T, D>>;

            fn $method(self, rhs: T) -> Self::Output {
                let mapped = self
                    .iter()
                    .copied()
                    .map(|elem| $op_fn(elem, rhs))
                    .collect::<Vec<T>>();
                Array::from_vec(self.dim().clone(), mapped)
            }
        }

        // The owned form borrows itself and delegates to the borrowed impl.
        impl<T, D> std::ops::$trait<T> for Array<T, D>
        where
            T: Element + Copy + std::ops::$trait<Output = T>,
            D: Dimension,
        {
            type Output = FerrayResult<Array<T, D>>;

            fn $method(self, rhs: T) -> Self::Output {
                <&Self as std::ops::$trait<T>>::$method(&self, rhs)
            }
        }
    };
}

// Array-with-scalar arithmetic.
impl_scalar_op!(Add, add, |a, b| a + b);
impl_scalar_op!(Sub, sub, |a, b| a - b);
impl_scalar_op!(Mul, mul, |a, b| a * b);
impl_scalar_op!(Div, div, |a, b| a / b);
impl_scalar_op!(Rem, rem, |a, b| a % b);
/// Element-wise negation of a borrowed array; the result keeps its dimension.
impl<T, D> std::ops::Neg for &Array<T, D>
where
    T: Element + Copy + std::ops::Neg<Output = T>,
    D: Dimension,
{
    type Output = FerrayResult<Array<T, D>>;

    fn neg(self) -> Self::Output {
        let negated: Vec<T> = self
            .iter()
            .copied()
            .map(<T as std::ops::Neg>::neg)
            .collect();
        Array::from_vec(self.dim().clone(), negated)
    }
}
/// Element-wise negation of an owned array, delegating to the borrowed impl.
impl<T, D> std::ops::Neg for Array<T, D>
where
    T: Element + Copy + std::ops::Neg<Output = T>,
    D: Dimension,
{
    type Output = FerrayResult<Self>;

    fn neg(self) -> Self::Output {
        std::ops::Neg::neg(&self)
    }
}
impl<T, D> Array<T, D>
where
    T: Element + Copy,
    D: Dimension,
{
    /// Element-wise `self + other` after broadcasting both operands to their
    /// common shape. The operands may have different dimension types; the
    /// result is dynamically ranked.
    ///
    /// # Errors
    ///
    /// Returns a shape-mismatch error if the shapes are not broadcast-compatible.
    pub fn add_broadcast<D2: Dimension>(
        &self,
        other: &Array<T, D2>,
    ) -> FerrayResult<Array<T, IxDyn>>
    where
        T: std::ops::Add<Output = T>,
    {
        elementwise_binary_dyn(self, other, <T as std::ops::Add>::add, "add_broadcast")
    }

    /// Element-wise `self - other` with cross-rank broadcasting; see
    /// [`Array::add_broadcast`] for the shape rules and errors.
    pub fn sub_broadcast<D2: Dimension>(
        &self,
        other: &Array<T, D2>,
    ) -> FerrayResult<Array<T, IxDyn>>
    where
        T: std::ops::Sub<Output = T>,
    {
        elementwise_binary_dyn(self, other, <T as std::ops::Sub>::sub, "sub_broadcast")
    }

    /// Element-wise `self * other` with cross-rank broadcasting; see
    /// [`Array::add_broadcast`] for the shape rules and errors.
    pub fn mul_broadcast<D2: Dimension>(
        &self,
        other: &Array<T, D2>,
    ) -> FerrayResult<Array<T, IxDyn>>
    where
        T: std::ops::Mul<Output = T>,
    {
        elementwise_binary_dyn(self, other, <T as std::ops::Mul>::mul, "mul_broadcast")
    }

    /// Element-wise `self / other` with cross-rank broadcasting; see
    /// [`Array::add_broadcast`] for the shape rules and errors.
    pub fn div_broadcast<D2: Dimension>(
        &self,
        other: &Array<T, D2>,
    ) -> FerrayResult<Array<T, IxDyn>>
    where
        T: std::ops::Div<Output = T>,
    {
        elementwise_binary_dyn(self, other, <T as std::ops::Div>::div, "div_broadcast")
    }

    /// Element-wise `self % other` with cross-rank broadcasting; see
    /// [`Array::add_broadcast`] for the shape rules and errors.
    pub fn rem_broadcast<D2: Dimension>(
        &self,
        other: &Array<T, D2>,
    ) -> FerrayResult<Array<T, IxDyn>>
    where
        T: std::ops::Rem<Output = T>,
    {
        elementwise_binary_dyn(self, other, <T as std::ops::Rem>::rem, "rem_broadcast")
    }
}
/// Implements compound assignment with a scalar (`array op= scalar`),
/// rewriting every element in place through `mapv_inplace`.
macro_rules! impl_scalar_op_assign {
    ($trait:ident, $method:ident, $op:tt) => {
        impl<T, D> std::ops::$trait<T> for Array<T, D>
        where
            T: Element + Copy + std::ops::$trait,
            D: Dimension,
        {
            fn $method(&mut self, rhs: T) {
                self.mapv_inplace(|element| {
                    let mut updated = element;
                    updated $op rhs;
                    updated
                });
            }
        }
    };
}

// `array += scalar`, `array -= scalar`, etc.
impl_scalar_op_assign!(AddAssign, add_assign, +=);
impl_scalar_op_assign!(SubAssign, sub_assign, -=);
impl_scalar_op_assign!(MulAssign, mul_assign, *=);
impl_scalar_op_assign!(DivAssign, div_assign, /=);
impl_scalar_op_assign!(RemAssign, rem_assign, %=);
/// Applies `op` in place: every element of `lhs` becomes `op(lhs_elem, rhs_elem)`.
///
/// `rhs` may be broadcast into `lhs`'s shape, but `lhs` never changes shape,
/// so an `rhs` that would require growing `lhs` is rejected.
///
/// # Errors
///
/// Returns a shape-mismatch error (labelled with `op_name`) when `rhs` cannot
/// be broadcast to `lhs`'s shape; `lhs` is left untouched in that case.
/// On the same-shape path, any error from `zip_mut_with` is propagated.
fn inplace_binary<T, D, F>(
    lhs: &mut Array<T, D>,
    rhs: &Array<T, D>,
    op: F,
    op_name: &str,
) -> FerrayResult<()>
where
    T: Element + Copy,
    D: Dimension,
    F: Fn(T, T) -> T,
{
    if lhs.shape() != rhs.shape() {
        // The destination shape is fixed; only `rhs` is allowed to stretch.
        let dest_shape: Vec<usize> = lhs.shape().to_vec();
        let rhs_view = broadcast_to(rhs, &dest_shape).map_err(|_| {
            FerrayError::shape_mismatch(format!(
                "{}: shape {:?} cannot be broadcast into destination shape {:?}",
                op_name,
                rhs.shape(),
                dest_shape
            ))
        })?;
        lhs.iter_mut()
            .zip(rhs_view.iter())
            .for_each(|(dst, src)| *dst = op(*dst, *src));
        return Ok(());
    }

    // Same shape: combine directly without building a broadcast view.
    lhs.zip_mut_with(rhs, |a, b| *a = op(*a, *b))
}
impl<T, D> Array<T, D>
where
    T: Element + Copy,
    D: Dimension,
{
    /// In-place `self += other`, broadcasting `other` into `self`'s shape.
    ///
    /// # Errors
    ///
    /// Fails, leaving `self` unchanged, if `other` cannot be broadcast to
    /// `self`'s shape.
    pub fn add_inplace(&mut self, other: &Self) -> FerrayResult<()>
    where
        T: std::ops::Add<Output = T>,
    {
        inplace_binary(self, other, <T as std::ops::Add>::add, "add_inplace")
    }

    /// In-place `self -= other`; same shape rules as [`Array::add_inplace`].
    pub fn sub_inplace(&mut self, other: &Self) -> FerrayResult<()>
    where
        T: std::ops::Sub<Output = T>,
    {
        inplace_binary(self, other, <T as std::ops::Sub>::sub, "sub_inplace")
    }

    /// In-place `self *= other`; same shape rules as [`Array::add_inplace`].
    pub fn mul_inplace(&mut self, other: &Self) -> FerrayResult<()>
    where
        T: std::ops::Mul<Output = T>,
    {
        inplace_binary(self, other, <T as std::ops::Mul>::mul, "mul_inplace")
    }

    /// In-place `self /= other`; same shape rules as [`Array::add_inplace`].
    pub fn div_inplace(&mut self, other: &Self) -> FerrayResult<()>
    where
        T: std::ops::Div<Output = T>,
    {
        inplace_binary(self, other, <T as std::ops::Div>::div, "div_inplace")
    }

    /// In-place `self %= other`; same shape rules as [`Array::add_inplace`].
    pub fn rem_inplace(&mut self, other: &Self) -> FerrayResult<()>
    where
        T: std::ops::Rem<Output = T>,
    {
        inplace_binary(self, other, <T as std::ops::Rem>::rem, "rem_inplace")
    }
}
/// Copies `src` into `dst`, broadcasting `src` to `dst`'s shape if needed.
///
/// `dst` keeps its shape; a `src` that would require growing `dst` is
/// rejected. Elements are copied with `Clone`.
///
/// # Errors
///
/// Returns a shape-mismatch error, leaving `dst` unchanged, if `src` cannot
/// be broadcast to `dst`'s shape.
pub fn copyto<T, D1, D2>(dst: &mut Array<T, D1>, src: &Array<T, D2>) -> FerrayResult<()>
where
    T: Element,
    D1: Dimension,
    D2: Dimension,
{
    if dst.shape() != src.shape() {
        let dest_shape: Vec<usize> = dst.shape().to_vec();
        let src_view = broadcast_to(src, &dest_shape).map_err(|_| {
            FerrayError::shape_mismatch(format!(
                "copyto: source shape {:?} cannot be broadcast into destination shape {:?}",
                src.shape(),
                dest_shape
            ))
        })?;
        dst.iter_mut()
            .zip(src_view.iter())
            .for_each(|(slot, value)| *slot = value.clone());
        return Ok(());
    }

    // Same shape: copy straight across without a broadcast view.
    dst.iter_mut()
        .zip(src.iter())
        .for_each(|(slot, value)| *slot = value.clone());
    Ok(())
}
/// Copies elements of `src` into `dst` only where `mask` is `true`.
///
/// Both `src` and `mask` are broadcast to `dst`'s shape, which never changes.
/// Positions where the mask is `false` keep their previous value.
///
/// # Errors
///
/// Returns a shape-mismatch error if `src` (checked first) or `mask` cannot
/// be broadcast to `dst`'s shape. Both checks happen before any write, so
/// `dst` is unchanged on error.
pub fn copyto_where<T, D1, D2, D3>(
    dst: &mut Array<T, D1>,
    src: &Array<T, D2>,
    mask: &Array<bool, D3>,
) -> FerrayResult<()>
where
    T: Element,
    D1: Dimension,
    D2: Dimension,
    D3: Dimension,
{
    let dest_shape: Vec<usize> = dst.shape().to_vec();

    let src_view = broadcast_to(src, &dest_shape).map_err(|_| {
        FerrayError::shape_mismatch(format!(
            "copyto_where: source shape {:?} cannot be broadcast into destination shape {:?}",
            src.shape(),
            dest_shape
        ))
    })?;
    let mask_view = broadcast_to(mask, &dest_shape).map_err(|_| {
        FerrayError::shape_mismatch(format!(
            "copyto_where: mask shape {:?} cannot be broadcast into destination shape {:?}",
            mask.shape(),
            dest_shape
        ))
    })?;

    let slots_and_values = dst.iter_mut().zip(src_view.iter());
    for ((slot, value), &selected) in slots_and_values.zip(mask_view.iter()) {
        if !selected {
            continue;
        }
        *slot = value.clone();
    }
    Ok(())
}
impl<T, D> Array<T, D>
where
    T: Element,
    D: Dimension,
{
    /// Method form of [`copyto`]: copies (broadcasting if needed) `src` into `self`.
    ///
    /// # Errors
    ///
    /// Same as [`copyto`].
    pub fn copy_from<D2: Dimension>(&mut self, src: &Array<T, D2>) -> FerrayResult<()> {
        copyto::<T, D, D2>(self, src)
    }

    /// Method form of [`copyto_where`]: copies `src` into `self` where `mask` is `true`.
    ///
    /// # Errors
    ///
    /// Same as [`copyto_where`].
    pub fn copy_from_where<D2: Dimension, D3: Dimension>(
        &mut self,
        src: &Array<T, D2>,
        mask: &Array<bool, D3>,
    ) -> FerrayResult<()> {
        copyto_where::<T, D, D2, D3>(self, src, mask)
    }
}
// Unit tests for the arithmetic operators, broadcasting helpers, in-place
// operations and copy routines defined in this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3, IxDyn};

    /// Builds a 1-D `f64` array whose length matches `data`.
    fn arr(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Builds a 1-D `i32` array (used where integer `%` semantics matter).
    fn arr_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    // --- array op array, identical shapes, all ownership combinations ---

    #[test]
    fn test_add_ref_ref() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![4.0, 5.0, 6.0]);
        let c = (&a + &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 7.0, 9.0]);
    }

    #[test]
    fn test_add_owned_owned() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![3.0, 4.0]);
        let c = (a + b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);
    }

    #[test]
    fn test_add_mixed() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![3.0, 4.0]);
        let c = (a + &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 6.0]);
        let d = arr(vec![10.0, 20.0]);
        let e = (&b + d).unwrap();
        assert_eq!(e.as_slice().unwrap(), &[13.0, 24.0]);
    }

    #[test]
    fn test_sub() {
        let a = arr(vec![5.0, 7.0]);
        let b = arr(vec![1.0, 2.0]);
        let c = (&a - &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[4.0, 5.0]);
    }

    #[test]
    fn test_mul() {
        let a = arr(vec![2.0, 3.0]);
        let b = arr(vec![4.0, 5.0]);
        let c = (&a * &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[8.0, 15.0]);
    }

    #[test]
    fn test_div() {
        let a = arr(vec![10.0, 20.0]);
        let b = arr(vec![2.0, 5.0]);
        let c = (&a / &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 4.0]);
    }

    #[test]
    fn test_rem() {
        let a = arr_i32(vec![7, 10]);
        let b = arr_i32(vec![3, 4]);
        let c = (&a % &b).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1, 2]);
    }

    // --- unary negation ---

    #[test]
    fn test_neg() {
        let a = arr(vec![1.0, -2.0, 3.0]);
        let b = (-&a).unwrap();
        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_neg_owned() {
        let a = arr(vec![1.0, -2.0]);
        let b = (-a).unwrap();
        assert_eq!(b.as_slice().unwrap(), &[-1.0, 2.0]);
    }

    #[test]
    fn test_shape_mismatch_errors() {
        let a = arr(vec![1.0, 2.0]);
        let b = arr(vec![1.0, 2.0, 3.0]);
        let result = &a + &b;
        assert!(result.is_err());
    }

    // --- array op scalar ---

    #[test]
    fn test_add_scalar() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let c = (&a + 10.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_sub_scalar() {
        let a = arr(vec![10.0, 20.0, 30.0]);
        let c = (&a - 5.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[5.0, 15.0, 25.0]);
    }

    #[test]
    fn test_mul_scalar() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let c = (&a * 3.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[3.0, 6.0, 9.0]);
    }

    #[test]
    fn test_div_scalar() {
        let a = arr(vec![10.0, 20.0, 30.0]);
        let c = (&a / 10.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_rem_scalar() {
        let a = arr_i32(vec![7, 10, 15]);
        let c = (&a % 4).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[3, 2, 3]);
    }

    #[test]
    fn test_scalar_op_owned() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let c = (a + 10.0).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_chained_ops() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![4.0, 5.0, 6.0]);
        let c = arr(vec![10.0, 10.0, 10.0]);
        let result = (&(&a + &b).unwrap() * &c).unwrap();
        assert_eq!(result.as_slice().unwrap(), &[50.0, 70.0, 90.0]);
    }

    // --- operator broadcasting between arrays of the same rank ---

    #[test]
    fn test_broadcast_2d_row_plus_column() {
        let col = Array::<f64, Ix2>::from_vec(Ix2::new([3, 1]), vec![1.0, 2.0, 3.0]).unwrap();
        let row =
            Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        let result = (&col + &row).unwrap();
        assert_eq!(result.shape(), &[3, 4]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[11.0, 21.0, 31.0, 41.0, 12.0, 22.0, 32.0, 42.0, 13.0, 23.0, 33.0, 43.0]
        );
    }

    #[test]
    fn test_broadcast_2d_stretch_one_axis() {
        let a = Array::<f64, Ix2>::from_vec(
            Ix2::new([3, 4]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
        )
        .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 4]), vec![100.0, 200.0, 300.0, 400.0])
            .unwrap();
        let result = (&a + &b).unwrap();
        assert_eq!(result.shape(), &[3, 4]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[101.0, 202.0, 303.0, 404.0, 105.0, 206.0, 307.0, 408.0, 109.0, 210.0, 311.0, 412.0]
        );
    }

    #[test]
    fn test_broadcast_3d_with_2d_axis() {
        let a =
            Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), (1..=24).map(|i| i as f64).collect())
                .unwrap();
        let b =
            Array::<f64, Ix3>::from_vec(Ix3::new([1, 3, 4]), (1..=12).map(|i| i as f64).collect())
                .unwrap();
        let result = (&a - &b).unwrap();
        assert_eq!(result.shape(), &[2, 3, 4]);
        // The first 3x4 block subtracts itself; the second is offset by exactly 12.
        let first_half = vec![0.0_f64; 12];
        assert_eq!(&result.as_slice().unwrap()[..12], &first_half[..]);
        let second_half = vec![12.0_f64; 12];
        assert_eq!(&result.as_slice().unwrap()[12..], &second_half[..]);
    }

    #[test]
    fn test_broadcast_incompatible_shapes_error() {
        let a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![1.0, 2.0, 3.0, 4.0]);
        let result = &a + &b;
        assert!(result.is_err());
        let c = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), vec![0.0; 12]).unwrap();
        let d = Array::<f64, Ix2>::from_vec(Ix2::new([3, 5]), vec![0.0; 15]).unwrap();
        assert!((&c + &d).is_err());
    }

    #[test]
    fn test_broadcast_mul_2d() {
        let col = Array::<i32, Ix2>::from_vec(Ix2::new([3, 1]), vec![1, 2, 3]).unwrap();
        let row = Array::<i32, Ix2>::from_vec(Ix2::new([1, 3]), vec![10, 20, 30]).unwrap();
        let result = (&col * &row).unwrap();
        assert_eq!(result.shape(), &[3, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[10, 20, 30, 20, 40, 60, 30, 60, 90]
        );
    }

    // --- cross-rank *_broadcast methods (IxDyn results) ---

    #[test]
    fn test_add_broadcast_1d_plus_2d() {
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let m =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
                .unwrap();
        let result = v.add_broadcast(&m).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[11.0, 22.0, 33.0, 41.0, 52.0, 63.0]
        );
    }

    #[test]
    fn test_add_broadcast_1d_plus_column() {
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let col = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![10.0, 20.0]).unwrap();
        let result = v.add_broadcast(&col).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[11.0, 12.0, 13.0, 21.0, 22.0, 23.0]
        );
    }

    #[test]
    fn test_sub_broadcast_2d_minus_1d() {
        let m =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
                .unwrap();
        let v = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let result = m.sub_broadcast(&v).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[9.0, 18.0, 27.0, 39.0, 48.0, 57.0]
        );
    }

    #[test]
    fn test_mul_broadcast_returns_dyn() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![10.0, 20.0]).unwrap();
        let result: Array<f64, IxDyn> = a.mul_broadcast(&b).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[10.0, 20.0, 30.0, 20.0, 40.0, 60.0]
        );
    }

    #[test]
    fn test_div_broadcast_incompatible() {
        let a = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        assert!(a.div_broadcast(&b).is_err());
    }

    #[test]
    fn test_rem_broadcast_2d() {
        let a =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![10, 20, 30, 40, 50, 60]).unwrap();
        let b = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![3, 7, 11]).unwrap();
        let result = a.rem_broadcast(&b).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
        assert_eq!(
            result.as_slice().unwrap(),
            &[10 % 3, 20 % 7, 30 % 11, 40 % 3, 50 % 7, 60 % 11]
        );
    }

    // --- compound assignment with a scalar ---

    #[test]
    fn scalar_add_assign_mutates_in_place() {
        let mut a = arr(vec![1.0, 2.0, 3.0]);
        a += 10.0;
        assert_eq!(a.as_slice().unwrap(), &[11.0, 12.0, 13.0]);
    }

    #[test]
    fn scalar_sub_mul_div_rem_assign() {
        let mut a = arr(vec![10.0, 20.0, 30.0]);
        a -= 1.0;
        assert_eq!(a.as_slice().unwrap(), &[9.0, 19.0, 29.0]);
        a *= 2.0;
        assert_eq!(a.as_slice().unwrap(), &[18.0, 38.0, 58.0]);
        a /= 2.0;
        assert_eq!(a.as_slice().unwrap(), &[9.0, 19.0, 29.0]);
        let mut b = arr_i32(vec![10, 11, 12]);
        b %= 3;
        assert_eq!(b.as_slice().unwrap(), &[1, 2, 0]);
    }

    #[test]
    fn scalar_assign_preserves_shape_ix2() {
        let mut a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        a += 1.0;
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a.as_slice().unwrap(), &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
    }

    // --- in-place array op array ---

    #[test]
    fn add_inplace_same_shape_fast_path() {
        let mut a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![10.0, 20.0, 30.0]);
        a.add_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0]);
    }

    #[test]
    fn sub_mul_div_rem_inplace_same_shape() {
        let mut a = arr(vec![10.0, 20.0, 30.0]);
        let b = arr(vec![1.0, 2.0, 3.0]);
        a.sub_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
        a.mul_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[9.0, 36.0, 81.0]);
        a.div_inplace(&b).unwrap();
        assert_eq!(a.as_slice().unwrap(), &[9.0, 18.0, 27.0]);
        let mut c = arr_i32(vec![10, 20, 30]);
        let d = arr_i32(vec![3, 7, 11]);
        c.rem_inplace(&d).unwrap();
        assert_eq!(c.as_slice().unwrap(), &[1, 6, 8]);
    }

    #[test]
    fn add_inplace_broadcasts_rhs_into_lhs_shape() {
        let mut a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        a.add_inplace(&b).unwrap();
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a.as_slice().unwrap(), &[11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);
    }

    #[test]
    fn add_inplace_broadcasts_column_into_rows() {
        let mut a =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 1]), vec![100.0, 200.0]).unwrap();
        a.add_inplace(&b).unwrap();
        assert_eq!(
            a.as_slice().unwrap(),
            &[101.0, 102.0, 103.0, 204.0, 205.0, 206.0]
        );
    }

    #[test]
    fn add_inplace_rejects_incompatible_rhs() {
        let mut a = arr(vec![1.0, 2.0, 3.0]);
        let b = arr(vec![1.0, 2.0, 3.0, 4.0]);
        assert!(a.add_inplace(&b).is_err());
        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn add_inplace_rejects_growing_shape() {
        let mut a = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let b = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0; 6]).unwrap();
        assert!(a.add_inplace(&b).is_err());
        assert_eq!(a.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    // --- copyto / copy_from ---

    #[test]
    fn copyto_same_shape_fast_path() {
        let mut dst = arr(vec![0.0, 0.0, 0.0]);
        let src = arr(vec![1.0, 2.0, 3.0]);
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_broadcasts_row_into_matrix() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let src = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![10.0, 20.0, 30.0]).unwrap();
        copyto(&mut dst, &src).unwrap();
        assert_eq!(
            dst.as_slice().unwrap(),
            &[10.0, 20.0, 30.0, 10.0, 20.0, 30.0]
        );
    }

    #[test]
    fn copyto_broadcasts_cross_rank_src() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let src = arr(vec![7.0, 8.0, 9.0]);
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[7.0, 8.0, 9.0, 7.0, 8.0, 9.0]);
    }

    #[test]
    fn copyto_scalar_src_broadcasts_to_full_dst() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let src = arr(vec![42.0]);
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[42.0; 6]);
    }

    #[test]
    fn copyto_rejects_growing_dst() {
        let mut dst = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let src = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![99.0; 6]).unwrap();
        assert!(copyto(&mut dst, &src).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_rejects_incompatible_shapes() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![1.0, 2.0, 3.0, 4.0]);
        assert!(copyto(&mut dst, &src).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_method_form_equivalent_to_function() {
        let mut dst = arr(vec![0.0, 0.0, 0.0]);
        let src = arr(vec![1.0, 2.0, 3.0]);
        dst.copy_from(&src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    // Exercises `copyto` with an integer element type. Note that `i64` is
    // `Copy`, so this does not cover a non-`Copy` element type.
    #[test]
    fn copyto_works_for_i64_elements() {
        let mut dst = Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![0, 0, 0, 0]).unwrap();
        let src = Array::<i64, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 3, 4]).unwrap();
        copyto(&mut dst, &src).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1, 2, 3, 4]);
    }

    // --- copyto_where / copy_from_where ---

    #[test]
    fn copyto_where_same_shape_only_writes_masked_positions() {
        let mut dst = arr(vec![1.0, 2.0, 3.0, 4.0]);
        let src = arr(vec![10.0, 20.0, 30.0, 40.0]);
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[10.0, 2.0, 30.0, 4.0]);
    }

    #[test]
    fn copyto_where_broadcasts_mask_across_dst() {
        let mut dst =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let src =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0])
                .unwrap();
        let mask =
            Array::<bool, Ix2>::from_vec(Ix2::new([1, 3]), vec![true, false, true]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[10.0, 2.0, 30.0, 40.0, 5.0, 60.0]);
    }

    #[test]
    fn copyto_where_broadcasts_scalar_src_with_mask() {
        let mut dst =
            Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        let src = arr(vec![99.0]);
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, false, true, false, true, false],
        )
        .unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[99.0, 2.0, 99.0, 4.0, 99.0, 6.0]);
    }

    #[test]
    fn copyto_where_all_false_mask_is_noop() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let original = dst.as_slice().unwrap().to_vec();
        let src = arr(vec![99.0, 99.0, 99.0]);
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &original[..]);
    }

    #[test]
    fn copyto_where_all_true_mask_matches_copyto() {
        let mut dst = arr(vec![0.0, 0.0, 0.0]);
        let src = arr(vec![1.0, 2.0, 3.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        copyto_where(&mut dst, &src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_where_rejects_incompatible_src_shape() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![1.0, 2.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        assert!(copyto_where(&mut dst, &src, &mask).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copyto_where_rejects_incompatible_mask_shape() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![10.0, 20.0, 30.0]);
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true; 4]).unwrap();
        assert!(copyto_where(&mut dst, &src, &mask).is_err());
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn copy_from_where_method_form_equivalent() {
        let mut dst = arr(vec![1.0, 2.0, 3.0]);
        let src = arr(vec![10.0, 20.0, 30.0]);
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, true, false]).unwrap();
        dst.copy_from_where(&src, &mask).unwrap();
        assert_eq!(dst.as_slice().unwrap(), &[1.0, 20.0, 3.0]);
    }

    // Floating-point division by zero follows IEEE 754 rather than erroring.
    #[test]
    fn div_inplace_by_zero_yields_ieee_sentinels() {
        let mut a = arr(vec![1.0, 2.0, 0.0]);
        let b = arr(vec![2.0, 0.0, 0.0]);
        a.div_inplace(&b).unwrap();
        let s = a.as_slice().unwrap();
        assert_eq!(s[0], 0.5);
        assert!(s[1].is_infinite() && s[1].is_sign_positive());
        assert!(s[2].is_nan());
    }
}