use std::ops::{Add, Div, Mul, Sub};
use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
use crate::MaskedArray;
/// Combines two boolean masks element-wise with logical OR.
///
/// A position in the returned mask is `true` when it is `true` in either
/// `a` or `b`. The result takes its dimension from `a`.
///
/// # Errors
///
/// Returns a shape-mismatch error when the two masks have different shapes,
/// and propagates any error from `Array::from_vec`.
fn mask_union<D: Dimension>(
    a: &Array<bool, D>,
    b: &Array<bool, D>,
) -> FerrayResult<Array<bool, D>> {
    let (lhs_shape, rhs_shape) = (a.shape(), b.shape());
    if lhs_shape != rhs_shape {
        return Err(FerrayError::shape_mismatch(format!(
            "mask_union: shapes {:?} and {:?} differ",
            lhs_shape, rhs_shape
        )));
    }
    // Both operands are plain bools, so the non-short-circuiting `|` is
    // equivalent to `||` here.
    let combined = a
        .iter()
        .zip(b.iter())
        .map(|(&left, &right)| left | right)
        .collect::<Vec<bool>>();
    Array::from_vec(a.dim().clone(), combined)
}
/// Element-wise sum of two masked arrays that share a shape.
///
/// The result mask is the logical OR of both input masks. At masked
/// positions no addition is performed and the stored value is `T::zero()`.
/// Everywhere else the value is `a + b`.
///
/// # Errors
///
/// Returns a shape-mismatch error when `a` and `b` differ in shape, and
/// propagates any error raised while assembling the result arrays.
pub fn masked_add<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Add<Output = T> + Copy,
    D: Dimension,
{
    let (lhs_shape, rhs_shape) = (a.shape(), b.shape());
    if lhs_shape != rhs_shape {
        return Err(FerrayError::shape_mismatch(format!(
            "masked_add: shapes {:?} and {:?} differ",
            lhs_shape, rhs_shape
        )));
    }
    let union = mask_union(a.mask(), b.mask())?;
    let sums = union
        .iter()
        .zip(a.data().iter().zip(b.data().iter()))
        .map(|(&masked, (&lhs, &rhs))| if masked { T::zero() } else { lhs + rhs })
        .collect::<Vec<T>>();
    MaskedArray::new(Array::from_vec(a.dim().clone(), sums)?, union)
}
/// Element-wise difference `a - b` of two masked arrays that share a shape.
///
/// The result mask is the logical OR of both input masks. At masked
/// positions no subtraction is performed and the stored value is
/// `T::zero()`.
///
/// # Errors
///
/// Returns a shape-mismatch error when `a` and `b` differ in shape, and
/// propagates any error raised while assembling the result arrays.
pub fn masked_sub<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Sub<Output = T> + Copy,
    D: Dimension,
{
    let (lhs_shape, rhs_shape) = (a.shape(), b.shape());
    if lhs_shape != rhs_shape {
        return Err(FerrayError::shape_mismatch(format!(
            "masked_sub: shapes {:?} and {:?} differ",
            lhs_shape, rhs_shape
        )));
    }
    let union = mask_union(a.mask(), b.mask())?;
    let differences = union
        .iter()
        .zip(a.data().iter().zip(b.data().iter()))
        .map(|(&masked, (&lhs, &rhs))| if masked { T::zero() } else { lhs - rhs })
        .collect::<Vec<T>>();
    MaskedArray::new(Array::from_vec(a.dim().clone(), differences)?, union)
}
/// Element-wise product of two masked arrays that share a shape.
///
/// The result mask is the logical OR of both input masks. At masked
/// positions no multiplication is performed and the stored value is
/// `T::zero()`. Everywhere else the value is `a * b`.
///
/// # Errors
///
/// Returns a shape-mismatch error when `a` and `b` differ in shape, and
/// propagates any error raised while assembling the result arrays.
pub fn masked_mul<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Mul<Output = T> + Copy,
    D: Dimension,
{
    let (lhs_shape, rhs_shape) = (a.shape(), b.shape());
    if lhs_shape != rhs_shape {
        return Err(FerrayError::shape_mismatch(format!(
            "masked_mul: shapes {:?} and {:?} differ",
            lhs_shape, rhs_shape
        )));
    }
    let union = mask_union(a.mask(), b.mask())?;
    let products = union
        .iter()
        .zip(a.data().iter().zip(b.data().iter()))
        .map(|(&masked, (&lhs, &rhs))| if masked { T::zero() } else { lhs * rhs })
        .collect::<Vec<T>>();
    MaskedArray::new(Array::from_vec(a.dim().clone(), products)?, union)
}
/// Element-wise quotient `a / b` of two masked arrays that share a shape.
///
/// The result mask is the logical OR of both input masks. At masked
/// positions no division is performed and the stored value is `T::zero()`,
/// so a masked zero divisor is never evaluated.
///
/// # Errors
///
/// Returns a shape-mismatch error when `a` and `b` differ in shape, and
/// propagates any error raised while assembling the result arrays.
///
/// # Panics
///
/// Unmasked positions are divided with `T`'s `Div` impl, and zero divisors
/// are not checked. For Rust's primitive integer types an unmasked zero in
/// `b` therefore panics. Floating-point types produce infinities or NaN instead.
pub fn masked_div<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Div<Output = T> + Copy,
    D: Dimension,
{
    let (lhs_shape, rhs_shape) = (a.shape(), b.shape());
    if lhs_shape != rhs_shape {
        return Err(FerrayError::shape_mismatch(format!(
            "masked_div: shapes {:?} and {:?} differ",
            lhs_shape, rhs_shape
        )));
    }
    let union = mask_union(a.mask(), b.mask())?;
    let quotients = union
        .iter()
        .zip(a.data().iter().zip(b.data().iter()))
        .map(|(&masked, (&lhs, &rhs))| if masked { T::zero() } else { lhs / rhs })
        .collect::<Vec<T>>();
    MaskedArray::new(Array::from_vec(a.dim().clone(), quotients)?, union)
}
/// Element-wise sum of a masked array and a plain (unmasked) array.
///
/// The result reuses a copy of `ma`'s mask unchanged, because `arr`
/// contributes no masked positions. Masked positions hold `T::zero()`
/// without evaluating the addition. All other positions hold
/// `ma.data + arr`.
///
/// # Errors
///
/// Returns a shape-mismatch error when `ma` and `arr` differ in shape, and
/// propagates any error raised while assembling the result arrays.
pub fn masked_add_array<T, D>(
    ma: &MaskedArray<T, D>,
    arr: &Array<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Add<Output = T> + Copy,
    D: Dimension,
{
    let (masked_shape, plain_shape) = (ma.shape(), arr.shape());
    if masked_shape != plain_shape {
        return Err(FerrayError::shape_mismatch(format!(
            "masked_add_array: shapes {:?} and {:?} differ",
            masked_shape, plain_shape
        )));
    }
    let mask = ma.mask();
    let sums = mask
        .iter()
        .zip(ma.data().iter().zip(arr.iter()))
        .map(|(&masked, (&lhs, &rhs))| if masked { T::zero() } else { lhs + rhs })
        .collect::<Vec<T>>();
    MaskedArray::new(Array::from_vec(ma.dim().clone(), sums)?, mask.clone())
}
/// Element-wise difference `ma - arr` of a masked array and a plain array.
///
/// The result reuses a copy of `ma`'s mask unchanged. Masked positions
/// hold `T::zero()` without evaluating the subtraction.
///
/// # Errors
///
/// Returns a shape-mismatch error when `ma` and `arr` differ in shape, and
/// propagates any error raised while assembling the result arrays.
pub fn masked_sub_array<T, D>(
    ma: &MaskedArray<T, D>,
    arr: &Array<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Sub<Output = T> + Copy,
    D: Dimension,
{
    let (masked_shape, plain_shape) = (ma.shape(), arr.shape());
    if masked_shape != plain_shape {
        return Err(FerrayError::shape_mismatch(format!(
            "masked_sub_array: shapes {:?} and {:?} differ",
            masked_shape, plain_shape
        )));
    }
    let mask = ma.mask();
    let differences = mask
        .iter()
        .zip(ma.data().iter().zip(arr.iter()))
        .map(|(&masked, (&lhs, &rhs))| if masked { T::zero() } else { lhs - rhs })
        .collect::<Vec<T>>();
    MaskedArray::new(Array::from_vec(ma.dim().clone(), differences)?, mask.clone())
}
/// Element-wise product of a masked array and a plain (unmasked) array.
///
/// The result reuses a copy of `ma`'s mask unchanged. Masked positions
/// hold `T::zero()` without evaluating the multiplication.
///
/// # Errors
///
/// Returns a shape-mismatch error when `ma` and `arr` differ in shape, and
/// propagates any error raised while assembling the result arrays.
pub fn masked_mul_array<T, D>(
    ma: &MaskedArray<T, D>,
    arr: &Array<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Mul<Output = T> + Copy,
    D: Dimension,
{
    let (masked_shape, plain_shape) = (ma.shape(), arr.shape());
    if masked_shape != plain_shape {
        return Err(FerrayError::shape_mismatch(format!(
            "masked_mul_array: shapes {:?} and {:?} differ",
            masked_shape, plain_shape
        )));
    }
    let mask = ma.mask();
    let products = mask
        .iter()
        .zip(ma.data().iter().zip(arr.iter()))
        .map(|(&masked, (&lhs, &rhs))| if masked { T::zero() } else { lhs * rhs })
        .collect::<Vec<T>>();
    MaskedArray::new(Array::from_vec(ma.dim().clone(), products)?, mask.clone())
}
/// Element-wise quotient `ma / arr` of a masked array and a plain array.
///
/// The result reuses a copy of `ma`'s mask unchanged. Masked positions
/// hold `T::zero()` without evaluating the division.
///
/// # Errors
///
/// Returns a shape-mismatch error when `ma` and `arr` differ in shape, and
/// propagates any error raised while assembling the result arrays.
///
/// # Panics
///
/// Unmasked positions are divided with `T`'s `Div` impl, and zero divisors
/// in `arr` are not checked. For Rust's primitive integer types such a zero
/// therefore panics. Floating-point types produce infinities or NaN instead.
pub fn masked_div_array<T, D>(
    ma: &MaskedArray<T, D>,
    arr: &Array<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Div<Output = T> + Copy,
    D: Dimension,
{
    let (masked_shape, plain_shape) = (ma.shape(), arr.shape());
    if masked_shape != plain_shape {
        return Err(FerrayError::shape_mismatch(format!(
            "masked_div_array: shapes {:?} and {:?} differ",
            masked_shape, plain_shape
        )));
    }
    let mask = ma.mask();
    let quotients = mask
        .iter()
        .zip(ma.data().iter().zip(arr.iter()))
        .map(|(&masked, (&lhs, &rhs))| if masked { T::zero() } else { lhs / rhs })
        .collect::<Vec<T>>();
    MaskedArray::new(Array::from_vec(ma.dim().clone(), quotients)?, mask.clone())
}