use std::sync::{Arc, OnceLock};
use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
/// An array of values paired with a same-shaped boolean mask.
///
/// The mask is stored behind `Arc<OnceLock<..>>`: clones share the same
/// allocation, and an array built without a mask keeps the `OnceLock` empty
/// until a mask is actually requested or written.
pub struct MaskedArray<T: Element, D: Dimension> {
/// The underlying element values.
data: Array<T, D>,
/// Mask storage shared between clones. May be unset (no mask allocated yet);
/// when materialized lazily it is filled with `false`.
mask: Arc<OnceLock<Array<bool, D>>>,
/// `true` once an explicit mask has been supplied or written. While `false`,
/// `mask_opt` reports `None` even if the `OnceLock` was lazily filled.
real_mask: bool,
/// When `true`, `set_mask_flat` ignores writes of `false` and `set_mask`
/// ORs the new mask into an existing real mask instead of replacing it.
pub(crate) hard_mask: bool,
/// Stored fill value; the constructors in this file initialize it to `T::zero()`.
pub(crate) fill_value: T,
}
/// Debug output lists the data and the bookkeeping flags. The mask field is
/// left out, which `finish_non_exhaustive` signals in the rendered output.
impl<T: Element, D: Dimension> std::fmt::Debug for MaskedArray<T, D> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Bind only the printed fields; `..` skips `mask`.
        let Self {
            data,
            real_mask,
            hard_mask,
            fill_value,
            ..
        } = self;

        let mut builder = f.debug_struct("MaskedArray");
        builder.field("data", data);
        builder.field("real_mask", real_mask);
        builder.field("hard_mask", hard_mask);
        builder.field("fill_value", fill_value);
        builder.finish_non_exhaustive()
    }
}
/// Cloning copies the data and flags but only bumps the reference count of
/// the mask `Arc`, so the clone and the original share mask storage.
impl<T: Element + Clone, D: Dimension> Clone for MaskedArray<T, D> {
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field forces this impl to be revisited.
        let Self {
            data,
            mask,
            real_mask,
            hard_mask,
            fill_value,
        } = self;

        Self {
            data: data.clone(),
            mask: Arc::clone(mask),
            real_mask: *real_mask,
            hard_mask: *hard_mask,
            fill_value: fill_value.clone(),
        }
    }
}
impl<T: Element, D: Dimension> MaskedArray<T, D> {
/// Builds a masked array from `data` and an explicit `mask`.
///
/// The result has a real mask, a soft (non-hard) mask, and a fill value of
/// `T::zero()`.
///
/// # Errors
///
/// Returns a shape-mismatch error when `data` and `mask` have different shapes.
pub fn new(data: Array<T, D>, mask: Array<bool, D>) -> FerrayResult<Self> {
    let shapes_agree = data.shape() == mask.shape();
    if !shapes_agree {
        let message = format!(
            "MaskedArray::new: data shape {:?} does not match mask shape {:?}",
            data.shape(),
            mask.shape()
        );
        return Err(FerrayError::shape_mismatch(message));
    }

    // The lock is created already initialized with the caller's mask.
    let shared_mask = Arc::new(OnceLock::from(mask));
    Ok(Self {
        data,
        mask: shared_mask,
        real_mask: true,
        hard_mask: false,
        fill_value: T::zero(),
    })
}
/// Wraps `data` without allocating a mask.
///
/// The mask lock starts empty and `has_real_mask()` is `false`; the fill
/// value is `T::zero()`. This constructor currently always returns `Ok`.
pub fn from_data(data: Array<T, D>) -> FerrayResult<Self> {
    let unset_mask = Arc::new(OnceLock::new());
    let unmasked = Self {
        data,
        mask: unset_mask,
        real_mask: false,
        hard_mask: false,
        fill_value: T::zero(),
    };
    Ok(unmasked)
}
/// Returns `true` when an explicit mask has been supplied or written, and
/// `false` while the array is still in its mask-less state.
#[inline]
pub const fn has_real_mask(&self) -> bool {
self.real_mask
}
/// Returns a copy of the stored fill value (requires `T: Copy`).
// NOTE(review): presumably the value substituted for masked elements by
// operations defined elsewhere — confirm against the callers.
#[inline]
pub const fn fill_value(&self) -> T
where
T: Copy,
{
self.fill_value
}
/// Builder-style variant of `set_fill_value`: consumes the array, stores
/// `fill_value`, and hands the array back.
#[must_use]
pub fn with_fill_value(mut self, fill_value: T) -> Self {
    self.set_fill_value(fill_value);
    self
}
/// Replaces the stored fill value in place.
pub fn set_fill_value(&mut self, fill_value: T) {
self.fill_value = fill_value;
}
/// Borrows the underlying data array; the mask is not consulted.
#[inline]
pub const fn data(&self) -> &Array<T, D> {
&self.data
}
/// Borrows the mask, allocating an all-`false` mask shaped like the data on
/// first use if none has been stored yet.
///
/// Lazy allocation does not change `has_real_mask()`. The allocated mask
/// lives in the shared `OnceLock`, so repeated calls return the same array.
///
/// # Panics
///
/// Panics if `from_elem` fails for the data's own dimension.
pub fn mask(&self) -> &Array<bool, D> {
    let all_clear = || {
        Array::<bool, D>::from_elem(self.data.dim().clone(), false)
            .expect("from_elem with matching dim cannot fail")
    };
    self.mask.get_or_init(all_clear)
}
/// Returns the mask only when the array has a real mask.
///
/// Yields `None` while `has_real_mask()` is `false`, even if a lazily
/// allocated all-`false` mask already sits in the lock. Never allocates.
#[inline]
pub fn mask_opt(&self) -> Option<&Array<bool, D>> {
    if !self.real_mask {
        return None;
    }
    self.mask.get()
}
/// Mutably borrows the underlying data array; the mask is left untouched.
// NOTE(review): nothing here re-checks that the mask still matches the data's
// shape after the caller mutates it — confirm `Array` cannot be reshaped
// through `&mut`.
#[inline]
pub const fn data_mut(&mut self) -> &mut Array<T, D> {
&mut self.data
}
/// Returns the shape of the underlying data array.
#[inline]
pub fn shape(&self) -> &[usize] {
self.data.shape()
}
/// Returns the number of dimensions reported by the underlying data array.
#[inline]
pub fn ndim(&self) -> usize {
self.data.ndim()
}
/// Returns the element count reported by the underlying data array; this is
/// the exclusive upper bound for flat indices accepted by `set_mask_flat`.
#[inline]
pub fn size(&self) -> usize {
self.data.size()
}
/// Borrows the dimension descriptor of the underlying data array.
#[inline]
pub const fn dim(&self) -> &D {
self.data.dim()
}
/// Returns `true` when the mask is hard: `set_mask_flat` then ignores writes
/// of `false`, and `set_mask` ORs into an existing real mask.
#[inline]
pub const fn is_hard_mask(&self) -> bool {
self.hard_mask
}
/// Returns a mutable reference to a real mask that this array owns exclusively.
///
/// Two steps:
/// 1. If the array has no real mask (or the lock is empty), a fresh
///    all-`false` mask is installed in a new `Arc` and the array is promoted
///    to the real-mask state. Any lazily materialized mask is discarded.
/// 2. If the `Arc` is still shared with clones, `Arc::make_mut` clones the
///    mask into a private allocation so the clones are unaffected by the
///    caller's writes.
///
/// # Panics
///
/// Panics if `from_elem` fails for the data's own dimension.
fn ensure_materialized_mut(&mut self) -> &mut Array<bool, D> {
    if !self.real_mask || self.mask.get().is_none() {
        let fresh = Array::<bool, D>::from_elem(self.data.dim().clone(), false)
            .expect("from_elem with matching dim cannot fail");
        self.mask = Arc::new(OnceLock::from(fresh));
        self.real_mask = true;
    }

    // `make_mut` clones the lock's contents only when another handle exists,
    // replacing the hand-written "get_mut, else clone and swap" sequence.
    Arc::make_mut(&mut self.mask)
        .get_mut()
        .expect("OnceLock was initialized above")
}
/// Sets the mask bit at flat (iteration-order) index `flat_idx` to `value`.
///
/// Writing `false` is a no-op when the mask is hard or when the array has no
/// real mask yet, so clearing never allocates. Any other write materializes
/// a privately owned real mask first.
///
/// # Errors
///
/// Returns an index-out-of-bounds error when `flat_idx >= self.size()`.
pub fn set_mask_flat(&mut self, flat_idx: usize, value: bool) -> FerrayResult<()> {
    let total = self.size();
    if flat_idx >= total {
        return Err(FerrayError::index_out_of_bounds(flat_idx as isize, 0, total));
    }

    // Clearing is skipped for a hard mask, and when nothing is masked anyway.
    let clearing = !value;
    if clearing && (self.hard_mask || !self.real_mask) {
        return Ok(());
    }

    let mask = self.ensure_materialized_mut();
    if let Some(flat) = mask.as_slice_mut() {
        flat[flat_idx] = value;
    } else if let Some(slot) = mask.iter_mut().nth(flat_idx) {
        // No slice view available: walk the iterator to the target element.
        *slot = value;
    }
    Ok(())
}
/// Installs `new_mask` as the array's mask and promotes it to a real mask.
///
/// With a hard mask and an existing real mask, the stored mask becomes the
/// element-wise OR of the old and new masks; otherwise `new_mask` replaces
/// whatever was there. The result always goes into a fresh `Arc`, so clones
/// sharing the previous mask are not affected.
///
/// # Errors
///
/// Returns a shape-mismatch error when `new_mask` and the data differ in
/// shape, and forwards any error from rebuilding the merged array.
pub fn set_mask(&mut self, new_mask: Array<bool, D>) -> FerrayResult<()> {
    if self.data.shape() != new_mask.shape() {
        let message = format!(
            "set_mask: mask shape {:?} does not match array shape {:?}",
            new_mask.shape(),
            self.data.shape()
        );
        return Err(FerrayError::shape_mismatch(message));
    }

    let replacement = if self.hard_mask && self.real_mask {
        // A hard mask can only grow: keep every bit that is already set.
        let existing = self.mask.get().expect("real_mask implies OnceLock set");
        let merged: Vec<bool> = existing
            .iter()
            .zip(new_mask.iter())
            .map(|(prev, incoming)| *prev || *incoming)
            .collect();
        Array::from_vec(self.data.dim().clone(), merged)?
    } else {
        new_mask
    };

    self.mask = Arc::new(OnceLock::from(replacement));
    self.real_mask = true;
    Ok(())
}
/// Returns `true` when more than one handle (this array plus at least one
/// clone) currently points at the same mask `Arc`.
#[inline]
pub fn shares_mask(&self) -> bool {
Arc::strong_count(&self.mask) > 1
}
}
/// Unit tests for mask-less construction, lazy materialization, clone
/// sharing, copy-on-write isolation and hard-mask merging.
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Array;
    use ferray_core::dimension::Ix1;

    /// Builds a 1-D `f64` array from `data`.
    fn arr_f64(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Builds a 1-D `bool` array from `data`.
    fn arr_bool(data: Vec<bool>) -> Array<bool, Ix1> {
        let n = data.len();
        Array::<bool, Ix1>::from_vec(Ix1::new([n]), data).unwrap()
    }

    #[test]
    fn from_data_starts_in_nomask_sentinel_state() {
        let ma = MaskedArray::from_data(arr_f64(vec![1.0, 2.0, 3.0])).unwrap();
        assert!(!ma.has_real_mask());
        assert!(ma.mask_opt().is_none());
    }

    #[test]
    fn new_with_explicit_mask_is_real_mask() {
        let ma = MaskedArray::new(
            arr_f64(vec![1.0, 2.0, 3.0]),
            arr_bool(vec![false, true, false]),
        )
        .unwrap();
        assert!(ma.has_real_mask());
        assert!(ma.mask_opt().is_some());
    }

    #[test]
    fn mask_accessor_lazily_materializes_nomask_sentinel() {
        let ma = MaskedArray::from_data(arr_f64(vec![1.0, 2.0, 3.0])).unwrap();
        assert!(ma.mask_opt().is_none());
        let m = ma.mask();
        assert_eq!(m.shape(), &[3]);
        assert_eq!(
            m.iter().copied().collect::<Vec<_>>(),
            vec![false, false, false]
        );
        // A second call must hand back the very same allocation.
        let m2 = ma.mask();
        assert_eq!(std::ptr::from_ref(m), std::ptr::from_ref(m2));
        // Lazy materialization must not promote to a real mask.
        assert!(!ma.has_real_mask());
    }

    #[test]
    fn set_mask_flat_false_on_nomask_stays_zero_allocation() {
        let mut ma = MaskedArray::from_data(arr_f64(vec![1.0, 2.0, 3.0])).unwrap();
        ma.set_mask_flat(1, false).unwrap();
        assert!(!ma.has_real_mask());
        assert!(ma.mask_opt().is_none());
    }

    #[test]
    fn set_mask_flat_true_on_nomask_materializes_and_promotes() {
        let mut ma = MaskedArray::from_data(arr_f64(vec![1.0, 2.0, 3.0])).unwrap();
        ma.set_mask_flat(1, true).unwrap();
        assert!(ma.has_real_mask());
        let m: Vec<bool> = ma.mask().iter().copied().collect();
        assert_eq!(m, vec![false, true, false]);
    }

    #[test]
    fn set_mask_promotes_and_keeps_provided_values() {
        let mut ma = MaskedArray::from_data(arr_f64(vec![1.0, 2.0, 3.0])).unwrap();
        assert!(!ma.has_real_mask());
        ma.set_mask(arr_bool(vec![true, false, true])).unwrap();
        assert!(ma.has_real_mask());
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, false, true]
        );
    }

    #[test]
    fn set_mask_shape_mismatch_errors() {
        let mut ma = MaskedArray::from_data(arr_f64(vec![1.0, 2.0, 3.0])).unwrap();
        assert!(ma.set_mask(arr_bool(vec![false; 4])).is_err());
    }

    #[test]
    fn clone_preserves_nomask_sentinel_state() {
        let ma = MaskedArray::from_data(arr_f64(vec![1.0, 2.0, 3.0])).unwrap();
        // Call `.clone()` rather than moving `ma`, so `Clone` is what gets tested.
        let cloned = ma.clone();
        assert!(!cloned.has_real_mask());
        assert!(cloned.mask_opt().is_none());
        // The original must be left in the same mask-less state.
        assert!(!ma.has_real_mask());
        assert!(ma.mask_opt().is_none());
    }

    #[test]
    fn clone_after_materialization_copies_the_mask() {
        let ma = MaskedArray::from_data(arr_f64(vec![1.0, 2.0, 3.0])).unwrap();
        let _ = ma.mask();
        // Call `.clone()` rather than moving `ma`, so `Clone` is what gets tested.
        let cloned = ma.clone();
        assert!(!cloned.has_real_mask());
        assert_eq!(
            cloned.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, false]
        );
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, false]
        );
    }

    #[test]
    fn clone_preserves_real_mask_state() {
        let ma = MaskedArray::new(
            arr_f64(vec![1.0, 2.0, 3.0]),
            arr_bool(vec![false, true, false]),
        )
        .unwrap();
        // Call `.clone()` rather than moving `ma`, so `Clone` is what gets tested.
        let cloned = ma.clone();
        assert!(cloned.has_real_mask());
        assert_eq!(
            cloned.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, false]
        );
        // The original keeps its real mask and values.
        assert!(ma.has_real_mask());
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, false]
        );
    }

    #[test]
    fn clone_shares_mask_via_arc() {
        let ma = MaskedArray::new(
            arr_f64(vec![1.0, 2.0, 3.0]),
            arr_bool(vec![false, true, false]),
        )
        .unwrap();
        let cloned = ma.clone();
        assert!(ma.shares_mask());
        assert!(cloned.shares_mask());
    }

    #[test]
    fn unique_masked_array_does_not_share() {
        let ma = MaskedArray::new(
            arr_f64(vec![1.0, 2.0, 3.0]),
            arr_bool(vec![false, true, false]),
        )
        .unwrap();
        assert!(!ma.shares_mask());
    }

    #[test]
    fn copy_on_write_isolates_parent_from_child_mutation() {
        let parent = MaskedArray::new(
            arr_f64(vec![1.0, 2.0, 3.0]),
            arr_bool(vec![false, false, false]),
        )
        .unwrap();
        let mut child = parent.clone();
        assert!(parent.shares_mask());
        assert!(child.shares_mask());
        child.set_mask_flat(1, true).unwrap();
        assert_eq!(
            parent.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, false]
        );
        assert_eq!(
            child.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, false]
        );
        // After the write each side owns its own mask.
        assert!(!parent.shares_mask());
        assert!(!child.shares_mask());
    }

    #[test]
    fn copy_on_write_via_set_mask() {
        let parent = MaskedArray::new(
            arr_f64(vec![1.0, 2.0, 3.0]),
            arr_bool(vec![false, false, false]),
        )
        .unwrap();
        let mut child = parent.clone();
        assert!(parent.shares_mask());
        child.set_mask(arr_bool(vec![true, true, true])).unwrap();
        assert_eq!(
            parent.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, false]
        );
        assert_eq!(
            child.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, true, true]
        );
        assert!(!parent.shares_mask());
    }

    #[test]
    fn nomask_sentinel_clones_share_empty_arc() {
        let parent = MaskedArray::from_data(arr_f64(vec![1.0, 2.0, 3.0])).unwrap();
        let cloned = parent.clone();
        assert!(parent.shares_mask());
        assert!(cloned.shares_mask());
        assert!(!parent.has_real_mask());
        assert!(!cloned.has_real_mask());
    }

    #[test]
    fn hard_mask_union_on_real_mask() {
        let mut ma = MaskedArray::new(
            arr_f64(vec![1.0, 2.0, 3.0]),
            arr_bool(vec![true, false, false]),
        )
        .unwrap();
        ma.harden_mask().unwrap();
        ma.set_mask(arr_bool(vec![false, false, true])).unwrap();
        // Hard mask: the previously set bit survives and the new bit is added.
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, false, true]
        );
    }
}