use pyo3::prelude::*;
/// An array of `f64` values stored in a flat buffer together with a
/// per-element boolean mask.
///
/// A `true` entry in `mask` marks the corresponding element as masked: it is
/// excluded from `count`/`sum`/`mean` and replaced by `fill_value` in
/// `filled`/`apply_unmasked`.
#[pyclass(name = "MaskedArray")]
#[derive(Debug)]
pub struct MaskedArray {
    /// Element values; the constructors keep its length equal to the product of `shape`.
    data: Vec<f64>,
    /// One flag per element of `data`; `true` means the element is masked.
    mask: Vec<bool>,
    /// Logical dimensions of the array.
    shape: Vec<usize>,
    /// Value substituted for masked elements when producing filled output.
    fill_value: f64,
}
#[pymethods]
impl MaskedArray {
#[new]
pub fn new(
data: Vec<f64>,
mask: Option<Vec<bool>>,
shape: Vec<usize>,
fill_value: Option<f64>,
) -> PyResult<Self> {
let n: usize = shape.iter().product();
if data.len() != n {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"data length {} does not match shape product {}",
data.len(),
n
)));
}
let mask = mask.unwrap_or_else(|| vec![false; n]);
if mask.len() != n {
return Err(pyo3::exceptions::PyValueError::new_err(
"mask length does not match shape product",
));
}
Ok(Self {
data,
mask,
shape,
fill_value: fill_value.unwrap_or(f64::NAN),
})
}
pub fn filled(&self) -> Vec<f64> {
self.data
.iter()
.zip(self.mask.iter())
.map(|(&d, &m)| if m { self.fill_value } else { d })
.collect()
}
pub fn count(&self) -> usize {
self.mask.iter().filter(|&&m| !m).count()
}
pub fn mean(&self) -> Option<f64> {
let valid: Vec<f64> = self
.data
.iter()
.zip(self.mask.iter())
.filter(|(_, &m)| !m)
.map(|(&d, _)| d)
.collect();
if valid.is_empty() {
None
} else {
Some(valid.iter().sum::<f64>() / valid.len() as f64)
}
}
pub fn sum(&self) -> f64 {
self.data
.iter()
.zip(self.mask.iter())
.filter(|(_, &m)| !m)
.map(|(&d, _)| d)
.sum()
}
pub fn shape(&self) -> Vec<usize> {
self.shape.clone()
}
pub fn data(&self) -> Vec<f64> {
self.data.clone()
}
pub fn mask(&self) -> Vec<bool> {
self.mask.clone()
}
pub fn fill_value(&self) -> f64 {
self.fill_value
}
pub fn mask_element(&mut self, idx: usize, masked: bool) -> PyResult<()> {
if idx >= self.mask.len() {
return Err(pyo3::exceptions::PyIndexError::new_err(
"index out of bounds",
));
}
self.mask[idx] = masked;
Ok(())
}
pub fn apply_unmasked(&self, op: &str) -> PyResult<Vec<f64>> {
let fill = self.fill_value;
match op {
"abs" => Ok(self
.data
.iter()
.zip(self.mask.iter())
.map(|(&d, &m)| if m { fill } else { d.abs() })
.collect()),
"sqrt" => Ok(self
.data
.iter()
.zip(self.mask.iter())
.map(|(&d, &m)| if m { fill } else { d.sqrt() })
.collect()),
"log" => Ok(self
.data
.iter()
.zip(self.mask.iter())
.map(|(&d, &m)| if m { fill } else { d.ln() })
.collect()),
_ => Err(pyo3::exceptions::PyValueError::new_err(format!(
"unknown operation '{op}'; supported: abs, sqrt, log"
))),
}
}
}
/// Build a one-dimensional masked array whose shape is `[data.len()]`.
///
/// The fill value is NaN. Raises `ValueError` when `mask` does not have the
/// same length as `data`.
#[pyfunction]
pub fn masked_array(data: Vec<f64>, mask: Vec<bool>) -> PyResult<MaskedArray> {
    let one_dim_shape = vec![data.len()];
    MaskedArray::new(data, Some(mask), one_dim_shape, None)
}
/// Build a one-dimensional masked array in which every element strictly less
/// than `threshold` is masked.
///
/// NaN elements are never masked, because `NaN < threshold` is false. The
/// fill value is NaN.
#[pyfunction]
pub fn masked_less(data: Vec<f64>, threshold: f64) -> MaskedArray {
    let below = data
        .iter()
        .map(|value| *value < threshold)
        .collect::<Vec<bool>>();
    let flat_shape = vec![data.len()];
    MaskedArray {
        fill_value: f64::NAN,
        shape: flat_shape,
        mask: below,
        data,
    }
}
/// Register the `MaskedArray` class and the `masked_array` / `masked_less`
/// constructor functions on module `m`.
pub fn register_masked_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<MaskedArray>()?;

    let array_ctor = wrap_pyfunction!(masked_array, m)?;
    m.add_function(array_ctor)?;

    let less_ctor = wrap_pyfunction!(masked_less, m)?;
    m.add_function(less_ctor)?;

    Ok(())
}