use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::FerrayResult;
use num_traits::Float;
use crate::MaskedArray;
/// Applies `f` to every unmasked element of `ma`, producing a new masked array.
///
/// The result keeps a clone of the input mask. Positions where the mask is
/// `true` are not passed to `f`; their data slot is filled with `T::zero()`.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new` while
/// assembling the result.
fn masked_unary_op<T, D>(
    ma: &MaskedArray<T, D>,
    f: impl Fn(T) -> T,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Copy,
    D: Dimension,
{
    let mask = ma.mask();
    let mut out: Vec<T> = Vec::new();
    for (value, hidden) in ma.data().iter().zip(mask.iter()) {
        // Masked entries are never evaluated, so `f` cannot fault on them.
        let computed = if *hidden { T::zero() } else { f(*value) };
        out.push(computed);
    }
    let shape = ma.dim().clone();
    MaskedArray::new(Array::from_vec(shape, out)?, mask.clone())
}
/// Combines `a` and `b` element-wise with `f`, producing a new masked array.
///
/// The result mask is the logical OR of both input masks. A position hidden
/// in either operand is not passed to `f`; its data slot is filled with
/// `T::zero()`. The result takes its shape from `a`.
///
/// NOTE(review): the shapes of `a` and `b` are not compared here. `zip` stops
/// at the shorter input, so mismatched operands are only caught if the
/// resulting length disagrees with `a`'s shape in `Array::from_vec`. Consider
/// an explicit shape check.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new` while
/// assembling the result.
fn masked_binary_op<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
    f: impl Fn(T, T) -> T,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Copy,
    D: Dimension,
{
    // Union of the two masks: hidden if either operand hides the element.
    let combined: Vec<bool> = a
        .mask()
        .iter()
        .zip(b.mask().iter())
        .map(|(left, right)| *left || *right)
        .collect();
    let result_mask = Array::from_vec(a.dim().clone(), combined)?;

    let mut out: Vec<T> = Vec::new();
    for ((x, y), hidden) in a
        .data()
        .iter()
        .zip(b.data().iter())
        .zip(result_mask.iter())
    {
        let computed = if *hidden { T::zero() } else { f(*x, *y) };
        out.push(computed);
    }

    let result_data = Array::from_vec(a.dim().clone(), out)?;
    MaskedArray::new(result_data, result_mask)
}
/// Element-wise sine of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn sin<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.sin())
}
/// Element-wise cosine of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn cos<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.cos())
}
/// Element-wise tangent of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn tan<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.tan())
}
/// Element-wise inverse sine of `ma`, via `Float::asin`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn arcsin<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.asin())
}
/// Element-wise inverse cosine of `ma`, via `Float::acos`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn arccos<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.acos())
}
/// Element-wise inverse tangent of `ma`, via `Float::atan`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn arctan<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.atan())
}
/// Element-wise natural exponential `e^x` of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn exp<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.exp())
}
/// Element-wise base-2 exponential `2^x` of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn exp2<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.exp2())
}
/// Element-wise natural logarithm of `ma`, via `Float::ln`.
///
/// Masked elements stay masked and are not evaluated. Unmasked values outside
/// the domain are not auto-masked; they yield whatever `Float::ln` returns.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn log<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.ln())
}
/// Element-wise base-2 logarithm of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn log2<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.log2())
}
/// Element-wise base-10 logarithm of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn log10<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.log10())
}
/// Element-wise floor (round toward negative infinity) of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn floor<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.floor())
}
/// Element-wise ceiling (round toward positive infinity) of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn ceil<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.ceil())
}
/// Element-wise square root of `ma`.
///
/// Masked elements stay masked and are not evaluated. Negative unmasked inputs
/// are not auto-masked; they yield whatever `Float::sqrt` returns.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn sqrt<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.sqrt())
}
/// Element-wise absolute value of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn absolute<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.abs())
}
/// Element-wise arithmetic negation of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn negative<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| -x)
}
/// Element-wise reciprocal `1 / x` of `ma`, via `Float::recip`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn reciprocal<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_unary_op(ma, |x| x.recip())
}
/// Element-wise square `x * x` of `ma`.
///
/// Masked elements stay masked and are not evaluated.
///
/// # Errors
///
/// Propagates any error raised while building the result array.
pub fn square<T, D>(ma: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    let squared = |x: T| x * x;
    masked_unary_op(ma, squared)
}
/// Element-wise sum `a + b`.
///
/// The result is masked wherever either operand is masked.
///
/// # Errors
///
/// Propagates any error raised while building the result arrays.
pub fn add<T, D>(a: &MaskedArray<T, D>, b: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    let sum = |lhs: T, rhs: T| lhs + rhs;
    masked_binary_op(a, b, sum)
}
/// Element-wise difference `a - b`.
///
/// The result is masked wherever either operand is masked.
///
/// # Errors
///
/// Propagates any error raised while building the result arrays.
pub fn subtract<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    let difference = |lhs: T, rhs: T| lhs - rhs;
    masked_binary_op(a, b, difference)
}
/// Element-wise product `a * b`.
///
/// The result is masked wherever either operand is masked.
///
/// # Errors
///
/// Propagates any error raised while building the result arrays.
pub fn multiply<T, D>(
    a: &MaskedArray<T, D>,
    b: &MaskedArray<T, D>,
) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    let product = |lhs: T, rhs: T| lhs * rhs;
    masked_binary_op(a, b, product)
}
/// Element-wise quotient `a / b`.
///
/// The result is masked wherever either operand is masked. Division by zero in
/// unmasked positions is not auto-masked; it yields the floating-point result
/// of `/`.
///
/// # Errors
///
/// Propagates any error raised while building the result arrays.
pub fn divide<T, D>(a: &MaskedArray<T, D>, b: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    let quotient = |lhs: T, rhs: T| lhs / rhs;
    masked_binary_op(a, b, quotient)
}
/// Element-wise power `a ^ b`, via `Float::powf`.
///
/// The result is masked wherever either operand is masked.
///
/// # Errors
///
/// Propagates any error raised while building the result arrays.
pub fn power<T, D>(a: &MaskedArray<T, D>, b: &MaskedArray<T, D>) -> FerrayResult<MaskedArray<T, D>>
where
    T: Element + Float,
    D: Dimension,
{
    masked_binary_op(a, b, |base, exponent| base.powf(exponent))
}