use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
/// An array of values paired with a boolean mask array of the same dimension type.
///
/// `MaskedArray::new` rejects a mask whose shape differs from the data's, and
/// `MaskedArray::from_data` builds an all-`false` mask from the data's dimension,
/// so both constructors yield a mask with the same shape as the data.
#[derive(Debug, Clone)]
pub struct MaskedArray<T: Element, D: Dimension> {
/// The underlying element values.
data: Array<T, D>,
/// One boolean flag per element of `data`.
mask: Array<bool, D>,
/// When `true`, the mask setters may turn flags on but never turn them off.
pub(crate) hard_mask: bool,
}
impl<T: Element, D: Dimension> MaskedArray<T, D> {
    /// Builds a masked array from `data` and an explicit `mask`.
    ///
    /// The result starts with `hard_mask` disabled.
    ///
    /// # Errors
    ///
    /// Returns a shape-mismatch error when `data` and `mask` do not have the
    /// same shape.
    pub fn new(data: Array<T, D>, mask: Array<bool, D>) -> FerrayResult<Self> {
        if data.shape() != mask.shape() {
            return Err(FerrayError::shape_mismatch(format!(
                "MaskedArray::new: data shape {:?} does not match mask shape {:?}",
                data.shape(),
                mask.shape()
            )));
        }
        Ok(Self {
            data,
            mask,
            hard_mask: false,
        })
    }

    /// Wraps `data` with a mask in which every flag is `false`.
    ///
    /// The result starts with `hard_mask` disabled.
    ///
    /// # Errors
    ///
    /// Propagates any error from allocating the mask array.
    pub fn from_data(data: Array<T, D>) -> FerrayResult<Self> {
        let mask = Array::<bool, D>::from_elem(data.dim().clone(), false)?;
        Ok(Self {
            data,
            mask,
            hard_mask: false,
        })
    }

    /// Borrows the underlying data array.
    #[inline]
    pub fn data(&self) -> &Array<T, D> {
        &self.data
    }

    /// Borrows the mask array.
    #[inline]
    pub fn mask(&self) -> &Array<bool, D> {
        &self.mask
    }

    /// Mutably borrows the underlying data array; the mask is not touched.
    #[inline]
    pub fn data_mut(&mut self) -> &mut Array<T, D> {
        &mut self.data
    }

    /// Shape of the data array.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        self.data.shape()
    }

    /// Number of dimensions of the data array.
    #[inline]
    pub fn ndim(&self) -> usize {
        self.data.ndim()
    }

    /// Size of the data array, as reported by `Array::size`.
    #[inline]
    pub fn size(&self) -> usize {
        self.data.size()
    }

    /// Dimension descriptor of the data array.
    #[inline]
    pub fn dim(&self) -> &D {
        self.data.dim()
    }

    /// Reports whether the mask is currently hard.
    #[inline]
    pub fn is_hard_mask(&self) -> bool {
        self.hard_mask
    }

    /// Sets the mask flag at position `flat_idx` of the mask's iteration order.
    ///
    /// While the mask is hard, a request to clear a flag (`value == false`) is
    /// silently ignored and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Returns an index-out-of-bounds error when `flat_idx >= self.size()`.
    pub fn set_mask_flat(&mut self, flat_idx: usize, value: bool) -> FerrayResult<()> {
        let size = self.size();
        if flat_idx >= size {
            // The error constructor takes a signed index. Clamp first so an index
            // above `isize::MAX` is not reported as a negative number by a
            // wrapping `as` cast; after clamping the cast is lossless.
            let reported = flat_idx.min(isize::MAX as usize) as isize;
            return Err(FerrayError::index_out_of_bounds(reported, 0, size));
        }
        // A hard mask never lets a flag be cleared.
        if self.hard_mask && !value {
            return Ok(());
        }
        if let Some(flag) = self.mask.iter_mut().nth(flat_idx) {
            *flag = value;
        }
        Ok(())
    }

    /// Replaces the whole mask with `new_mask`.
    ///
    /// While the mask is hard, the existing mask is instead combined with
    /// `new_mask` by element-wise OR, so flags that are already set stay set.
    ///
    /// # Errors
    ///
    /// Returns a shape-mismatch error when `new_mask` has a different shape
    /// than the current mask.
    pub fn set_mask(&mut self, new_mask: Array<bool, D>) -> FerrayResult<()> {
        if self.mask.shape() != new_mask.shape() {
            return Err(FerrayError::shape_mismatch(format!(
                "set_mask: mask shape {:?} does not match array shape {:?}",
                new_mask.shape(),
                self.mask.shape()
            )));
        }
        if self.hard_mask {
            // Shapes match, so OR the incoming flags into the existing mask in
            // place; this avoids a temporary buffer and a fallible array rebuild.
            for (current, incoming) in self.mask.iter_mut().zip(new_mask.iter()) {
                *current = *current || *incoming;
            }
        } else {
            self.mask = new_mask;
        }
        Ok(())
    }
}