use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::FerrayResult;
use num_traits::Float;
use crate::MaskedArray;
/// Build a masked array from `data`, using `condition` as the mask.
///
/// Both inputs are cloned; the caller's arrays are left untouched.
///
/// # Errors
///
/// Propagates any error returned by `MaskedArray::new`.
pub fn masked_where<T: Element + Copy, D: Dimension>(
    condition: &Array<bool, D>,
    data: &Array<T, D>,
) -> FerrayResult<MaskedArray<T, D>> {
    let values = data.clone();
    let mask = condition.clone();
    MaskedArray::new(values, mask)
}
/// Build a masked array whose mask is `true` wherever `data` holds NaN or an
/// infinity.
///
/// The underlying values are copied unchanged, so the invalid entries are
/// still present in the returned array's data.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new`.
pub fn masked_invalid<T: Element + Float, D: Dimension>(
    data: &Array<T, D>,
) -> FerrayResult<MaskedArray<T, D>> {
    let is_invalid = |x: &T| x.is_nan() || x.is_infinite();
    let flags = data.iter().map(is_invalid).collect::<Vec<bool>>();
    MaskedArray::new(data.clone(), Array::from_vec(data.dim().clone(), flags)?)
}
/// Replace every NaN or infinite element of `data` with `fill_value` and mask
/// those positions.
///
/// The returned array's data holds `fill_value` at each invalid position, its
/// mask is `true` there, and its fill value is set to `fill_value`.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new`.
pub fn fix_invalid<T: Element + Float, D: Dimension>(
    data: &Array<T, D>,
    fill_value: T,
) -> FerrayResult<MaskedArray<T, D>> {
    // Single pass: each element yields its cleaned value paired with its mask flag.
    let (cleaned, flags): (Vec<T>, Vec<bool>) = data
        .iter()
        .map(|&x| {
            let invalid = x.is_nan() || x.is_infinite();
            (if invalid { fill_value } else { x }, invalid)
        })
        .unzip();

    let values = Array::from_vec(data.dim().clone(), cleaned)?;
    let mask = Array::from_vec(data.dim().clone(), flags)?;

    let mut fixed = MaskedArray::new(values, mask)?;
    fixed.set_fill_value(fill_value);
    Ok(fixed)
}
/// Build a masked array whose mask is `true` wherever an element of `data`
/// equals `value`.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new`.
pub fn masked_equal<T: Element + PartialEq + Copy, D: Dimension>(
    data: &Array<T, D>,
    value: T,
) -> FerrayResult<MaskedArray<T, D>> {
    let flags = data.iter().map(|&x| x == value).collect::<Vec<bool>>();
    let shape = data.dim().clone();
    MaskedArray::new(data.clone(), Array::from_vec(shape, flags)?)
}
/// Build a masked array whose mask is `true` wherever an element of `data`
/// differs from `value`.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new`.
pub fn masked_not_equal<T: Element + PartialEq + Copy, D: Dimension>(
    data: &Array<T, D>,
    value: T,
) -> FerrayResult<MaskedArray<T, D>> {
    let flags = data.iter().map(|&x| x != value).collect::<Vec<bool>>();
    let shape = data.dim().clone();
    MaskedArray::new(data.clone(), Array::from_vec(shape, flags)?)
}
/// Build a masked array whose mask is `true` wherever an element of `data`
/// is strictly greater than `value`.
///
/// Elements for which the comparison is undefined (e.g. NaN) evaluate to
/// `false` and are left unmasked.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new`.
pub fn masked_greater<T: Element + PartialOrd + Copy, D: Dimension>(
    data: &Array<T, D>,
    value: T,
) -> FerrayResult<MaskedArray<T, D>> {
    let flags = data.iter().map(|&x| x > value).collect::<Vec<bool>>();
    let shape = data.dim().clone();
    MaskedArray::new(data.clone(), Array::from_vec(shape, flags)?)
}
/// Build a masked array whose mask is `true` wherever an element of `data`
/// is strictly less than `value`.
///
/// Elements for which the comparison is undefined (e.g. NaN) evaluate to
/// `false` and are left unmasked.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new`.
pub fn masked_less<T: Element + PartialOrd + Copy, D: Dimension>(
    data: &Array<T, D>,
    value: T,
) -> FerrayResult<MaskedArray<T, D>> {
    let flags = data.iter().map(|&x| x < value).collect::<Vec<bool>>();
    let shape = data.dim().clone();
    MaskedArray::new(data.clone(), Array::from_vec(shape, flags)?)
}
/// Build a masked array whose mask is `true` wherever an element of `data`
/// is greater than or equal to `value`.
///
/// Elements for which the comparison is undefined (e.g. NaN) evaluate to
/// `false` and are left unmasked.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new`.
pub fn masked_greater_equal<T: Element + PartialOrd + Copy, D: Dimension>(
    data: &Array<T, D>,
    value: T,
) -> FerrayResult<MaskedArray<T, D>> {
    let flags = data.iter().map(|&x| x >= value).collect::<Vec<bool>>();
    let shape = data.dim().clone();
    MaskedArray::new(data.clone(), Array::from_vec(shape, flags)?)
}
/// Build a masked array whose mask is `true` wherever an element of `data`
/// is less than or equal to `value`.
///
/// Elements for which the comparison is undefined (e.g. NaN) evaluate to
/// `false` and are left unmasked.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new`.
pub fn masked_less_equal<T: Element + PartialOrd + Copy, D: Dimension>(
    data: &Array<T, D>,
    value: T,
) -> FerrayResult<MaskedArray<T, D>> {
    let flags = data.iter().map(|&x| x <= value).collect::<Vec<bool>>();
    let shape = data.dim().clone();
    MaskedArray::new(data.clone(), Array::from_vec(shape, flags)?)
}
/// Build a masked array whose mask is `true` wherever an element of `data`
/// lies inside the closed interval spanned by `v1` and `v2`.
///
/// The bounds may be given in either order: `masked_inside(a, 3, 1)` masks
/// the same elements as `masked_inside(a, 1, 3)`. This mirrors
/// `numpy.ma.masked_inside`. Both endpoints are included in the interval.
///
/// Elements for which the comparison is undefined (e.g. NaN) are left
/// unmasked.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new`.
pub fn masked_inside<T: Element + PartialOrd + Copy, D: Dimension>(
    data: &Array<T, D>,
    v1: T,
    v2: T,
) -> FerrayResult<MaskedArray<T, D>> {
    // Normalise the bounds so a reversed pair still describes the same
    // interval instead of an empty one. Incomparable bounds are left as given.
    let (lo, hi) = if v2 < v1 { (v2, v1) } else { (v1, v2) };
    let mask_data: Vec<bool> = data.iter().map(|v| *v >= lo && *v <= hi).collect();
    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
    MaskedArray::new(data.clone(), mask)
}
/// Build a masked array whose mask is `true` wherever an element of `data`
/// lies outside the closed interval spanned by `v1` and `v2`.
///
/// The bounds may be given in either order: `masked_outside(a, 3, 1)` masks
/// the same elements as `masked_outside(a, 1, 3)`. This mirrors
/// `numpy.ma.masked_outside`. Elements equal to either endpoint are inside
/// the interval and stay unmasked.
///
/// Elements for which the comparison is undefined (e.g. NaN) are left
/// unmasked.
///
/// # Errors
///
/// Propagates any error from `Array::from_vec` or `MaskedArray::new`.
pub fn masked_outside<T: Element + PartialOrd + Copy, D: Dimension>(
    data: &Array<T, D>,
    v1: T,
    v2: T,
) -> FerrayResult<MaskedArray<T, D>> {
    // Normalise the bounds so a reversed pair does not mask every element.
    // Incomparable bounds are left as given.
    let (lo, hi) = if v2 < v1 { (v2, v1) } else { (v1, v2) };
    let mask_data: Vec<bool> = data.iter().map(|v| *v < lo || *v > hi).collect();
    let mask = Array::from_vec(data.dim().clone(), mask_data)?;
    MaskedArray::new(data.clone(), mask)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};

    /// NaN and both infinities are masked and overwritten with the fill value.
    #[test]
    fn fix_invalid_masks_and_replaces_nan_and_inf() {
        let raw = vec![1.0, f64::NAN, 3.0, f64::INFINITY, f64::NEG_INFINITY, 6.0];
        let input = Array::<f64, Ix1>::from_vec(Ix1::new([6]), raw).unwrap();

        let fixed = fix_invalid(&input, -99.0).unwrap();

        let flags: Vec<bool> = fixed.mask().iter().copied().collect();
        assert_eq!(flags, vec![false, true, false, true, true, false]);
        let values: Vec<f64> = fixed.data().iter().copied().collect();
        assert_eq!(values, vec![1.0, -99.0, 3.0, -99.0, -99.0, 6.0]);
        assert_eq!(fixed.fill_value(), -99.0);
    }

    /// Input without invalid entries comes back unmasked and unchanged.
    #[test]
    fn fix_invalid_preserves_valid_values() {
        let input =
            Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        let fixed = fix_invalid(&input, 0.0).unwrap();

        let flags = fixed.mask().iter().copied().collect::<Vec<_>>();
        assert_eq!(flags, vec![false, false, false, false]);
        let values = fixed.data().iter().copied().collect::<Vec<_>>();
        assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// An all-NaN input is fully masked and fully replaced.
    #[test]
    fn fix_invalid_all_nan_input() {
        let input = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![f64::NAN; 3]).unwrap();

        let fixed = fix_invalid(&input, 0.0).unwrap();

        let flags = fixed.mask().iter().copied().collect::<Vec<_>>();
        assert_eq!(flags, vec![true, true, true]);
        let values = fixed.data().iter().copied().collect::<Vec<_>>();
        assert_eq!(values, vec![0.0, 0.0, 0.0]);
        assert!(fixed.data().iter().all(|v| !v.is_nan()));
    }

    /// `masked_invalid` keeps the NaN in the data; `fix_invalid` overwrites it.
    /// Both produce the same mask.
    #[test]
    fn fix_invalid_vs_masked_invalid_data_difference() {
        let input =
            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, f64::NAN, 3.0]).unwrap();

        let masked_only = masked_invalid(&input).unwrap();
        let repaired = fix_invalid(&input, -1.0).unwrap();

        let masked_flags = masked_only.mask().iter().copied().collect::<Vec<_>>();
        let repaired_flags = repaired.mask().iter().copied().collect::<Vec<_>>();
        assert_eq!(masked_flags, repaired_flags);

        assert!(masked_only.data().iter().nth(1).unwrap().is_nan());
        assert_eq!(*repaired.data().iter().nth(1).unwrap(), -1.0);
    }

    /// A two-dimensional input keeps its shape, and the mask follows the
    /// iteration order of the input.
    #[test]
    fn fix_invalid_2d_shape_preserved() {
        let raw = vec![1.0, f64::NAN, 3.0, 4.0, 5.0, f64::INFINITY];
        let input = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), raw).unwrap();

        let fixed = fix_invalid(&input, -1.0).unwrap();

        assert_eq!(fixed.shape(), &[2, 3]);
        let flags = fixed.mask().iter().copied().collect::<Vec<_>>();
        assert_eq!(flags, vec![false, true, false, false, false, true]);
    }
}