use std::ops::{Add, Div, Mul, Sub};
use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dimension::broadcast::{broadcast_shapes, broadcast_to};
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
use crate::MaskedArray;
/// Element-wise logical OR of two boolean masks that must share one shape.
///
/// # Errors
/// Returns a shape-mismatch error when `a` and `b` have different shapes;
/// any error from building the output array is propagated.
fn mask_union<D: Dimension>(
    a: &Array<bool, D>,
    b: &Array<bool, D>,
) -> FerrayResult<Array<bool, D>> {
    let (lhs_shape, rhs_shape) = (a.shape(), b.shape());
    if lhs_shape == rhs_shape {
        // Non-short-circuit `|` is equivalent to `||` for plain bools.
        let merged: Vec<bool> = a.iter().zip(b.iter()).map(|(&p, &q)| p | q).collect();
        Array::from_vec(a.dim().clone(), merged)
    } else {
        Err(FerrayError::shape_mismatch(format!(
            "mask_union: shapes {:?} and {:?} differ",
            lhs_shape, rhs_shape
        )))
    }
}
/// Flat operand buffers for a masked binary op after both inputs were
/// broadcast to one common shape.
///
/// Every vector is collected from a view broadcast to the same target shape,
/// so position `i` refers to the same output element in all four buffers.
struct BroadcastedPair<T> {
    /// Broadcast data of the left operand.
    a_data: Vec<T>,
    /// Broadcast data of the right operand.
    b_data: Vec<T>,
    /// Broadcast mask of the left operand (`true` = masked).
    a_mask: Vec<bool>,
    /// Broadcast mask of the right operand (`true` = masked).
    b_mask: Vec<bool>,
}
/// Broadcasts the data and masks of `a` and `b` to their common shape and
/// copies each broadcast view into a flat vector.
///
/// Returns the four buffers together with the broadcast shape expressed as the
/// dimension type `D`.
///
/// # Errors
/// Fails with a shape-mismatch error mentioning `op_name` when the shapes are
/// not broadcast-compatible, or when the broadcast shape cannot be represented
/// by `D`. Errors from `broadcast_to` are propagated unchanged.
fn broadcast_masked_pair<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
    op_name: &str,
) -> FerrayResult<(BroadcastedPair<T>, D)>
where
    T: Element + Copy,
    D: Dimension,
{
    let target_shape = match broadcast_shapes(a.shape(), b.shape()) {
        Ok(shape) => shape,
        Err(_) => {
            return Err(FerrayError::shape_mismatch(format!(
                "{}: shapes {:?} and {:?} are not broadcast-compatible",
                op_name,
                a.shape(),
                b.shape()
            )));
        }
    };

    // Build every broadcast view before copying anything out of them.
    let (lhs_data, lhs_mask) = (
        broadcast_to(a.data(), &target_shape)?,
        broadcast_to(a.mask(), &target_shape)?,
    );
    let (rhs_data, rhs_mask) = (
        broadcast_to(b.data(), &target_shape)?,
        broadcast_to(b.mask(), &target_shape)?,
    );

    let a_data: Vec<T> = lhs_data.iter().copied().collect();
    let a_mask: Vec<bool> = lhs_mask.iter().copied().collect();
    let b_data: Vec<T> = rhs_data.iter().copied().collect();
    let b_mask: Vec<bool> = rhs_mask.iter().copied().collect();

    match D::from_dim_slice(&target_shape) {
        Some(result_dim) => Ok((
            BroadcastedPair {
                a_data,
                a_mask,
                b_data,
                b_mask,
            },
            result_dim,
        )),
        None => Err(FerrayError::shape_mismatch(format!(
            "{op_name}: cannot represent broadcast result shape {target_shape:?} as the input dimension type"
        ))),
    }
}
/// Applies `f` to every unmasked element of `ma`.
///
/// Masked positions are never passed to `f`; they hold `ma`'s fill value in the
/// output instead. When `ma` has a real mask, the output reuses a clone of it.
/// Otherwise the output is built with `MaskedArray::from_data`. Either way the
/// output inherits `ma`'s fill value.
///
/// # Errors
/// Propagates any error from constructing the output arrays.
pub(crate) fn masked_unary_op<T, D, F>(
    ma: &MaskedArray<T, D>,
    f: F,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Copy,
    D: Dimension,
    F: Fn(T) -> T,
{
    let fill = ma.fill_value;
    let mut out = if ma.has_real_mask() {
        let values: Vec<T> = ma
            .data()
            .iter()
            .zip(ma.mask().iter())
            .map(|(&v, &masked)| if masked { fill } else { f(v) })
            .collect();
        MaskedArray::new(Array::from_vec(ma.dim().clone(), values)?, ma.mask().clone())?
    } else {
        // No element is masked, so every value goes through `f`.
        let values: Vec<T> = ma.data().iter().copied().map(&f).collect();
        MaskedArray::from_data(Array::from_vec(ma.dim().clone(), values)?)?
    };
    out.fill_value = fill;
    Ok(out)
}
pub(crate) fn masked_binary_op<T, D, F>(
a: &MaskedArray<T, D>,
b: &MaskedArray<T, D>,
op: F,
op_name: &str,
) -> FerrayResult<MaskedArray<T, D>>
where
T: Element + Copy,
D: Dimension,
F: Fn(T, T) -> T,
{
if a.shape() == b.shape() {
let fill = a.fill_value;
if !a.has_real_mask() && !b.has_real_mask() {
let data: Vec<T> = a
.data()
.iter()
.zip(b.data().iter())
.map(|(&x, &y)| op(x, y))
.collect();
let result_data = Array::from_vec(a.dim().clone(), data)?;
let mut result = MaskedArray::from_data(result_data)?;
result.fill_value = fill;
return Ok(result);
}
let result_mask = mask_union(a.mask(), b.mask())?;
let data: Vec<T> = a
.data()
.iter()
.zip(b.data().iter())
.zip(result_mask.iter())
.map(|((x, y), m)| if *m { fill } else { op(*x, *y) })
.collect();
let result_data = Array::from_vec(a.dim().clone(), data)?;
let mut result = MaskedArray::new(result_data, result_mask)?;
result.fill_value = fill;
return Ok(result);
}
let (pair, result_dim) = broadcast_masked_pair(a, b, op_name)?;
let fill = a.fill_value;
let n = pair.a_data.len();
let mut result_data = Vec::with_capacity(n);
let mut result_mask = Vec::with_capacity(n);
for i in 0..n {
let m = pair.a_mask[i] || pair.b_mask[i];
result_mask.push(m);
result_data.push(if m {
fill
} else {
op(pair.a_data[i], pair.b_data[i])
});
}
let data_arr = Array::from_vec(result_dim.clone(), result_data)?;
let mask_arr = Array::from_vec(result_dim, result_mask)?;
let mut out = MaskedArray::new(data_arr, mask_arr)?;
out.fill_value = fill;
Ok(out)
}
/// Element-wise `a + b` for masked arrays, broadcasting when shapes differ.
///
/// An output element is masked wherever either operand is masked. Such
/// positions hold `a`'s fill value, which the result also inherits.
///
/// # Errors
/// Returns a shape-mismatch error when the shapes are not broadcast-compatible.
pub fn masked_add<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Add<Output = T> + Copy,
    D: Dimension,
{
    masked_binary_op(a, b, <T as Add>::add, "masked_add")
}
/// Element-wise `a - b` for masked arrays, broadcasting when shapes differ.
///
/// An output element is masked wherever either operand is masked. Such
/// positions hold `a`'s fill value, which the result also inherits.
///
/// # Errors
/// Returns a shape-mismatch error when the shapes are not broadcast-compatible.
pub fn masked_sub<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Sub<Output = T> + Copy,
    D: Dimension,
{
    masked_binary_op(a, b, <T as Sub>::sub, "masked_sub")
}
/// Element-wise `a * b` for masked arrays, broadcasting when shapes differ.
///
/// An output element is masked wherever either operand is masked. Such
/// positions hold `a`'s fill value, which the result also inherits.
///
/// # Errors
/// Returns a shape-mismatch error when the shapes are not broadcast-compatible.
pub fn masked_mul<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Mul<Output = T> + Copy,
    D: Dimension,
{
    masked_binary_op(a, b, <T as Mul>::mul, "masked_mul")
}
/// Element-wise `a / b` for masked arrays, broadcasting when shapes differ.
///
/// An output element is masked wherever either operand is masked. Such
/// positions hold `a`'s fill value, and the division is not evaluated there.
///
/// A zero divisor at an unmasked position is not masked automatically (unlike
/// `numpy.ma.divide`). The result is whatever `T`'s `/` produces: for floats,
/// ±infinity or NaN. For primitive integers, Rust's `/` panics on a zero divisor.
///
/// # Errors
/// Returns a shape-mismatch error when the shapes are not broadcast-compatible.
pub fn masked_div<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Div<Output = T> + Copy,
    D: Dimension,
{
    masked_binary_op(a, b, <T as Div>::div, "masked_div")
}
/// Applies `op` between a masked array and a plain (unmasked) array,
/// broadcasting the operands when their shapes differ.
///
/// The plain array contributes no mask, so an output element is masked exactly
/// where `ma` is masked (after broadcasting). At masked positions `op` is not
/// called and the output holds `ma`'s fill value. The output also inherits
/// `ma`'s fill value.
///
/// When `ma` carries no real mask, the output is built with
/// `MaskedArray::from_data`, the same constructor `masked_binary_op` uses for
/// mask-less operands. This avoids allocating an all-false mask and skips the
/// per-element mask lookup.
///
/// # Errors
/// Returns a shape-mismatch error (mentioning `op_name`) when the shapes are
/// not broadcast-compatible, or when the broadcast shape cannot be represented
/// by `D`. Errors from `broadcast_to` and array construction are propagated.
fn masked_array_op<T, D, F>(
    ma: &MaskedArray<T, D>,
    arr: &Array<T, D>,
    op: F,
    op_name: &str,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Copy,
    D: Dimension,
    F: Fn(T, T) -> T,
{
    let fill = ma.fill_value;
    let has_mask = ma.has_real_mask();

    let mut out = if ma.shape() == arr.shape() {
        if has_mask {
            let data: Vec<T> = ma
                .data()
                .iter()
                .zip(arr.iter())
                .zip(ma.mask().iter())
                .map(|((&x, &y), &m)| if m { fill } else { op(x, y) })
                .collect();
            let result_data = Array::from_vec(ma.dim().clone(), data)?;
            MaskedArray::new(result_data, ma.mask().clone())?
        } else {
            let data: Vec<T> = ma
                .data()
                .iter()
                .zip(arr.iter())
                .map(|(&x, &y)| op(x, y))
                .collect();
            MaskedArray::from_data(Array::from_vec(ma.dim().clone(), data)?)?
        }
    } else {
        let target_shape = broadcast_shapes(ma.shape(), arr.shape()).map_err(|_| {
            FerrayError::shape_mismatch(format!(
                "{}: shapes {:?} and {:?} are not broadcast-compatible",
                op_name,
                ma.shape(),
                arr.shape()
            ))
        })?;
        // Resolve the output dimension before any element-wise work so an
        // unrepresentable shape fails without computing anything.
        let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
            FerrayError::shape_mismatch(format!(
                "{op_name}: cannot represent broadcast result shape {target_shape:?} as the input dimension type"
            ))
        })?;
        let ma_data_view = broadcast_to(ma.data(), &target_shape)?;
        let arr_view = broadcast_to(arr, &target_shape)?;

        if has_mask {
            let ma_mask_view = broadcast_to(ma.mask(), &target_shape)?;
            // Zip the broadcast views directly instead of copying each into
            // an intermediate Vec first.
            let (data, mask): (Vec<T>, Vec<bool>) = ma_data_view
                .iter()
                .zip(ma_mask_view.iter())
                .zip(arr_view.iter())
                .map(|((&x, &m), &y)| (if m { fill } else { op(x, y) }, m))
                .unzip();
            let data_arr = Array::from_vec(result_dim.clone(), data)?;
            let mask_arr = Array::from_vec(result_dim, mask)?;
            MaskedArray::new(data_arr, mask_arr)?
        } else {
            let data: Vec<T> = ma_data_view
                .iter()
                .zip(arr_view.iter())
                .map(|(&x, &y)| op(x, y))
                .collect();
            MaskedArray::from_data(Array::from_vec(result_dim, data)?)?
        }
    };
    out.fill_value = fill;
    Ok(out)
}
/// Element-wise `ma + arr`, where `arr` is a plain array; shapes may broadcast.
///
/// The result is masked wherever `ma` is masked (after broadcasting). Those
/// positions hold `ma`'s fill value, which the result also inherits.
///
/// # Errors
/// Returns a shape-mismatch error when the shapes are not broadcast-compatible.
pub fn masked_add_array<T, D>(
    ma: &MaskedArray<T, D>,
    arr: &Array<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Add<Output = T> + Copy,
    D: Dimension,
{
    masked_array_op(ma, arr, <T as Add>::add, "masked_add_array")
}
/// Element-wise `ma - arr`, where `arr` is a plain array; shapes may broadcast.
///
/// The result is masked wherever `ma` is masked (after broadcasting). Those
/// positions hold `ma`'s fill value, which the result also inherits.
///
/// # Errors
/// Returns a shape-mismatch error when the shapes are not broadcast-compatible.
pub fn masked_sub_array<T, D>(
    ma: &MaskedArray<T, D>,
    arr: &Array<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Sub<Output = T> + Copy,
    D: Dimension,
{
    masked_array_op(ma, arr, <T as Sub>::sub, "masked_sub_array")
}
/// Element-wise `ma * arr`, where `arr` is a plain array; shapes may broadcast.
///
/// The result is masked wherever `ma` is masked (after broadcasting). Those
/// positions hold `ma`'s fill value, which the result also inherits.
///
/// # Errors
/// Returns a shape-mismatch error when the shapes are not broadcast-compatible.
pub fn masked_mul_array<T, D>(
    ma: &MaskedArray<T, D>,
    arr: &Array<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Mul<Output = T> + Copy,
    D: Dimension,
{
    masked_array_op(ma, arr, <T as Mul>::mul, "masked_mul_array")
}
/// Element-wise `ma / arr`, where `arr` is a plain array; shapes may broadcast.
///
/// The result is masked wherever `ma` is masked (after broadcasting). Those
/// positions hold `ma`'s fill value, which the result also inherits.
///
/// A zero in `arr` does not mask the result. The value is whatever `T`'s `/`
/// produces: for floats, ±infinity or NaN. For primitive integers, Rust's `/`
/// panics on a zero divisor.
///
/// # Errors
/// Returns a shape-mismatch error when the shapes are not broadcast-compatible.
pub fn masked_div_array<T, D>(
    ma: &MaskedArray<T, D>,
    arr: &Array<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Div<Output = T> + Copy,
    D: Dimension,
{
    masked_array_op(ma, arr, <T as Div>::div, "masked_div_array")
}
/// `&a + &b` for masked arrays; delegates to `masked_add`.
///
/// The output is a `FerrayResult` because broadcasting the operands can fail.
impl<T, D> std::ops::Add<&MaskedArray<T, D>> for &MaskedArray<T, D>
where
    D: Dimension,
    T: Element + Add<Output = T> + Copy,
{
    type Output = FerrayResult<MaskedArray<T, D>>;

    fn add(self, rhs: &MaskedArray<T, D>) -> Self::Output {
        masked_add::<T, D>(self, rhs)
    }
}
/// `&a - &b` for masked arrays; delegates to `masked_sub`.
///
/// The output is a `FerrayResult` because broadcasting the operands can fail.
impl<T, D> std::ops::Sub<&MaskedArray<T, D>> for &MaskedArray<T, D>
where
    D: Dimension,
    T: Element + Sub<Output = T> + Copy,
{
    type Output = FerrayResult<MaskedArray<T, D>>;

    fn sub(self, rhs: &MaskedArray<T, D>) -> Self::Output {
        masked_sub::<T, D>(self, rhs)
    }
}
/// `&a * &b` for masked arrays; delegates to `masked_mul`.
///
/// The output is a `FerrayResult` because broadcasting the operands can fail.
impl<T, D> std::ops::Mul<&MaskedArray<T, D>> for &MaskedArray<T, D>
where
    D: Dimension,
    T: Element + Mul<Output = T> + Copy,
{
    type Output = FerrayResult<MaskedArray<T, D>>;

    fn mul(self, rhs: &MaskedArray<T, D>) -> Self::Output {
        masked_mul::<T, D>(self, rhs)
    }
}
/// `&a / &b` for masked arrays; delegates to `masked_div`.
///
/// The output is a `FerrayResult` because broadcasting the operands can fail.
impl<T, D> std::ops::Div<&MaskedArray<T, D>> for &MaskedArray<T, D>
where
    D: Dimension,
    T: Element + Div<Output = T> + Copy,
{
    type Output = FerrayResult<MaskedArray<T, D>>;

    fn div(self, rhs: &MaskedArray<T, D>) -> Self::Output {
        masked_div::<T, D>(self, rhs)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Builds a 1-D masked array from parallel data and mask vectors.
    fn ma1d(data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix1> {
        let len = data.len();
        let values = Array::<f64, Ix1>::from_vec(Ix1::new([len]), data).unwrap();
        let flags = Array::<bool, Ix1>::from_vec(Ix1::new([len]), mask).unwrap();
        MaskedArray::new(values, flags).unwrap()
    }

    /// Copies the data of a 1-D masked array into a `Vec`.
    fn data_of(r: &MaskedArray<f64, Ix1>) -> Vec<f64> {
        r.data().iter().copied().collect()
    }

    /// Copies the mask of a 1-D masked array into a `Vec`.
    fn mask_of(r: &MaskedArray<f64, Ix1>) -> Vec<bool> {
        r.mask().iter().copied().collect()
    }

    #[test]
    fn masked_div_positive_by_zero_yields_positive_infinity_unmasked() {
        let numer = ma1d(vec![1.0, 2.0, 3.0], vec![false; 3]);
        let denom = ma1d(vec![1.0, 0.0, 3.0], vec![false; 3]);
        let out = masked_div(&numer, &denom).unwrap();
        let values = data_of(&out);
        assert_eq!(values[0], 1.0);
        assert!(values[1].is_infinite() && values[1].is_sign_positive());
        assert_eq!(values[2], 1.0);
        assert_eq!(mask_of(&out), vec![false, false, false]);
    }

    #[test]
    fn masked_div_negative_by_zero_yields_negative_infinity_unmasked() {
        let numer = ma1d(vec![-4.0], vec![false]);
        let denom = ma1d(vec![0.0], vec![false]);
        let out = masked_div(&numer, &denom).unwrap();
        let first = data_of(&out)[0];
        assert!(first.is_infinite() && first.is_sign_negative());
        assert!(!mask_of(&out)[0]);
    }

    #[test]
    fn masked_div_zero_by_zero_yields_nan_unmasked() {
        let numer = ma1d(vec![0.0], vec![false]);
        let denom = ma1d(vec![0.0], vec![false]);
        let out = masked_div(&numer, &denom).unwrap();
        assert!(data_of(&out)[0].is_nan());
        assert!(!mask_of(&out)[0]);
    }

    #[test]
    fn masked_div_skips_op_at_masked_divisor_positions() {
        let numer = ma1d(vec![1.0, 2.0, 3.0], vec![false; 3]).with_fill_value(-42.0);
        let denom = ma1d(vec![2.0, 0.0, 4.0], vec![false, true, false]);
        let out = masked_div(&numer, &denom).unwrap();
        let values = data_of(&out);
        assert_eq!(values, vec![0.5, -42.0, 0.75]);
        assert_eq!(mask_of(&out), vec![false, true, false]);
        // The masked slot carries the fill value, not the result of 2.0 / 0.0.
        assert!(!values[1].is_infinite() && !values[1].is_nan());
    }

    #[test]
    fn masked_div_array_by_zero_yields_infinity_unmasked() {
        let numer = ma1d(vec![5.0, 6.0], vec![false; 2]);
        let divisor = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![0.0, 2.0]).unwrap();
        let out = masked_div_array(&numer, &divisor).unwrap();
        let values = data_of(&out);
        assert!(values[0].is_infinite() && values[0].is_sign_positive());
        assert_eq!(values[1], 3.0);
        assert_eq!(mask_of(&out), vec![false, false]);
    }
}