use ferray_core::Array;
use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
use crate::MaskedArray;
impl<T: Element, D: Dimension> MaskedArray<T, D> {
    /// Switches this array to hard-mask mode by setting its `hard_mask` flag to `true`.
    ///
    /// NOTE(review): presumably mirrors NumPy's `harden_mask`, where a hard mask
    /// prevents masked entries from being unmasked by assignment — the enforcement
    /// lives outside this block, so confirm against the assignment paths.
    ///
    /// # Errors
    ///
    /// This implementation never fails; it always returns `Ok(())`.
    pub const fn harden_mask(&mut self) -> FerrayResult<()> {
        self.hard_mask = true;
        Ok(())
    }

    /// Switches this array back to soft-mask mode by clearing its `hard_mask` flag.
    ///
    /// This is the inverse of [`harden_mask`](Self::harden_mask).
    ///
    /// # Errors
    ///
    /// This implementation never fails; it always returns `Ok(())`.
    pub const fn soften_mask(&mut self) -> FerrayResult<()> {
        self.hard_mask = false;
        Ok(())
    }
}
/// Returns an owned copy of the mask of `ma`.
///
/// The mask is cloned, so later changes to `ma` do not affect the returned array.
///
/// # Errors
///
/// This implementation never fails; it always returns `Ok`.
pub fn getmask<T: Element, D: Dimension>(ma: &MaskedArray<T, D>) -> FerrayResult<Array<bool, D>> {
    let mask = ma.mask();
    Ok(mask.clone())
}
/// Returns an owned copy of the underlying data of `ma`, ignoring the mask.
///
/// The data array is cloned as-is; masked positions are neither filled nor removed.
///
/// # Errors
///
/// This implementation never fails; it always returns `Ok`.
pub fn getdata<T: Element + Copy, D: Dimension>(
    ma: &MaskedArray<T, D>,
) -> FerrayResult<Array<T, D>> {
    let data = ma.data();
    Ok(data.clone())
}
/// Reports whether at least one element of `ma` is masked.
///
/// Scans the mask and stops at the first `true` entry; an empty mask yields `false`.
///
/// # Errors
///
/// This implementation never fails; it always returns `Ok`.
pub fn is_masked<T: Element, D: Dimension>(ma: &MaskedArray<T, D>) -> FerrayResult<bool> {
    let any_masked = ma.mask().iter().any(|&flag| flag);
    Ok(any_masked)
}
/// Counts the masked elements of `ma` across the whole array.
///
/// The result is the number of `true` entries in the mask.
///
/// # Errors
///
/// This implementation never fails; it always returns `Ok`.
pub fn count_masked<T: Element, D: Dimension>(ma: &MaskedArray<T, D>) -> FerrayResult<usize> {
    Ok(ma.mask().iter().copied().filter(|&flag| flag).count())
}
/// Counts the masked elements of `ma` along `axis`, reducing that axis away.
///
/// The returned array has the shape of `ma` with the `axis` dimension removed
/// (a zero-dimensional array when `ma` is one-dimensional). Each entry holds the
/// number of `true` mask values found along `axis` at that position. When the
/// reduced axis has length zero, every entry is `0`.
///
/// # Errors
///
/// Returns [`FerrayError::axis_out_of_bounds`] when `axis >= ma.ndim()`, which
/// includes every call on a zero-dimensional array. Any error reported by
/// `Array::from_vec` while building the result is propagated unchanged.
pub fn count_masked_axis<T: Element, D: Dimension>(
    ma: &MaskedArray<T, D>,
    axis: usize,
) -> FerrayResult<ferray_core::Array<u64, ferray_core::dimension::IxDyn>> {
    use ferray_core::dimension::IxDyn;

    let ndim = ma.ndim();
    if axis >= ndim {
        return Err(FerrayError::axis_out_of_bounds(axis, ndim));
    }

    let shape = ma.shape();
    let axis_len = shape[axis];

    // View the array as (outer, axis_len, inner): `outer` is the number of index
    // combinations before `axis`, `inner` the number after it. An empty product
    // is 1, so a leading or trailing axis needs no special case.
    let outer: usize = shape.iter().take(axis).product();
    let inner: usize = shape.iter().skip(axis + 1).product();

    // Result shape: the input shape with the reduced axis dropped.
    let out_shape: Vec<usize> = shape
        .iter()
        .enumerate()
        .filter(|&(dim, _)| dim != axis)
        .map(|(_, &len)| len)
        .collect();

    // Flatten the mask once so it can be addressed by linear offset.
    // NOTE(review): this relies on `iter()` yielding elements in row-major
    // (C) order, the same assumption the previous stride-based code made.
    let mask_data: Vec<bool> = ma.mask().iter().copied().collect();

    let mut out_data: Vec<u64> = Vec::with_capacity(outer * inner);
    for o in 0..outer {
        // Offset of the (o, 0, 0) element; invariant across the two inner loops.
        let base = o * axis_len * inner;
        for i in 0..inner {
            // Consecutive elements along `axis` are `inner` apart in the flat buffer.
            let count: u64 = (0..axis_len)
                .map(|k| u64::from(mask_data[base + k * inner + i]))
                .sum();
            out_data.push(count);
        }
    }

    // Visiting (o, i) in this order is exactly row-major order over `out_shape`,
    // so `out_data` can be handed over without reordering.
    ferray_core::Array::from_vec(IxDyn::new(&out_shape), out_data)
}