use ferray_core::{Array as FerrayArray, IxDyn as FerrayIxDyn};
use ferray_ma::masked_array::MaskedArray;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// A tensor paired with a per-element validity mask and a fill value.
///
/// `mask[i] == true` means element `i` is valid and participates in the
/// `masked_*` reductions; `false` means it is masked out. `filled()` replaces
/// masked-out elements with `fill_value`.
#[derive(Clone, Debug)]
pub struct MaskedTensor<T: Float> {
    /// Underlying values, including the ones that are masked out.
    data: Tensor<T>,
    /// One flag per element of `data` (`true` = valid); `new` enforces that
    /// its length equals `data.numel()`.
    mask: Vec<bool>,
    /// Value substituted for masked-out elements by `filled()` (zero by default).
    fill_value: T,
}
impl<T: Float> MaskedTensor<T> {
    /// Pairs `data` with a validity `mask` (`true` = valid, `false` = masked out).
    ///
    /// The fill value starts at zero; override it with [`Self::with_fill_value`].
    ///
    /// # Errors
    /// [`FerrotorchError::ShapeMismatch`] if `mask.len() != data.numel()`.
    pub fn new(data: Tensor<T>, mask: Vec<bool>) -> FerrotorchResult<Self> {
        let numel = data.numel();
        if mask.len() != numel {
            let message = format!(
                "MaskedTensor::new: mask length {} != data numel {}",
                mask.len(),
                numel
            );
            return Err(FerrotorchError::ShapeMismatch { message });
        }
        let fill_value = <T as num_traits::Zero>::zero();
        Ok(Self {
            data,
            mask,
            fill_value,
        })
    }

    /// Wraps `data` with every element marked valid.
    ///
    /// # Errors
    /// Same as [`Self::new`] (cannot mismatch here, since the mask is sized from `data`).
    pub fn from_data(data: Tensor<T>) -> FerrotorchResult<Self> {
        let all_valid = vec![true; data.numel()];
        Self::new(data, all_valid)
    }

    /// Builder-style setter for the value [`Self::filled`] writes into masked slots.
    pub fn with_fill_value(self, fill_value: T) -> Self {
        Self { fill_value, ..self }
    }

    /// Borrow the underlying (unmasked) data tensor.
    #[inline]
    pub fn data(&self) -> &Tensor<T> {
        &self.data
    }

    /// Borrow the validity mask (`true` = valid).
    #[inline]
    pub fn mask(&self) -> &[bool] {
        &self.mask
    }

    /// The value used for masked-out elements by [`Self::filled`].
    #[inline]
    pub fn fill_value(&self) -> T {
        self.fill_value
    }

    /// Shape of the underlying data tensor.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        self.data.shape()
    }

    /// Total number of elements, masked or not.
    #[inline]
    pub fn numel(&self) -> usize {
        self.data.numel()
    }

    /// Number of elements whose mask flag is `true`.
    pub fn count_valid(&self) -> usize {
        self.mask.iter().filter(|valid| **valid).count()
    }

    /// Number of elements whose mask flag is `false`.
    pub fn count_masked(&self) -> usize {
        // Every flag is either valid or masked, so the two counts partition the mask.
        self.mask.len() - self.count_valid()
    }

    /// Builds a new CPU tensor of the same shape in which valid elements are
    /// copied through and masked-out elements are replaced by `fill_value`.
    ///
    /// # Errors
    /// Propagates failures from reading the data or building the output tensor.
    pub fn filled(&self) -> FerrotorchResult<Tensor<T>> {
        let fill = self.fill_value;
        let values = self.data.data_vec()?;
        let merged: Vec<T> = values
            .iter()
            .zip(&self.mask)
            .map(|(&value, &keep)| if keep { value } else { fill })
            .collect();
        Tensor::from_storage(TensorStorage::cpu(merged), self.data.shape().to_vec(), false)
    }

    /// Alias for [`Self::filled`].
    #[inline]
    pub fn to_tensor(&self) -> FerrotorchResult<Tensor<T>> {
        self.filled()
    }
}
impl<T: Float> MaskedTensor<T> {
    /// Converts this masked tensor into a `ferray_ma` [`MaskedArray`] with
    /// element type `U` and the same shape.
    ///
    /// Each element is converted through `f64`. The mask is inverted on the
    /// way out: this type uses `true` = valid, whereas the array handed to
    /// `MaskedArray::new` uses `true` = masked.
    ///
    /// `op` is a caller label prefixed to error messages.
    ///
    /// # Errors
    /// - [`FerrotorchError::InvalidArgument`] if an element cannot be
    ///   represented as `f64` or as `U`, or if `MaskedArray::new` rejects the
    ///   data/mask pair.
    /// - [`FerrotorchError::Ferray`] if building either ferray array fails.
    pub fn to_ferray<U>(&self, op: &'static str) -> FerrotorchResult<MaskedArray<U, FerrayIxDyn>>
    where
        U: ferray_core::Element + Copy + num_traits::Float + 'static,
    {
        let shape = self.data.shape();
        // Convert fallibly instead of panicking: a library conversion should
        // report an unrepresentable element as an error, tagged with `op`.
        let data_u: Vec<U> = self
            .data
            .data_vec()?
            .into_iter()
            .map(|v| {
                v.to_f64()
                    .and_then(<U as num_traits::NumCast>::from)
                    .ok_or_else(|| FerrotorchError::InvalidArgument {
                        message: format!("{op}: element not representable in target dtype"),
                    })
            })
            .collect::<FerrotorchResult<Vec<U>>>()?;
        let values = FerrayArray::<U, FerrayIxDyn>::from_vec(FerrayIxDyn::new(shape), data_u)
            .map_err(FerrotorchError::Ferray)?;
        // Flip each flag: ours marks valid elements, ferray's marks hidden ones.
        let hidden: Vec<bool> = self.mask.iter().map(|&valid| !valid).collect();
        let hidden_arr =
            FerrayArray::<bool, FerrayIxDyn>::from_vec(FerrayIxDyn::new(shape), hidden)
                .map_err(FerrotorchError::Ferray)?;
        MaskedArray::new(values, hidden_arr).map_err(|e| FerrotorchError::InvalidArgument {
            message: format!("{op}: {e}"),
        })
    }
}
/// Sums the valid (unmasked) elements of `mt` into a 0-dim tensor.
///
/// Dispatch:
/// - CUDA data with an f32/f64 element type goes to the GPU path.
/// - CUDA data with any other element type returns `NotImplementedOnCuda`.
/// - CPU data is summed sequentially.
///
/// If every element is masked, the CPU result is zero.
///
/// # Errors
/// `NotImplementedOnCuda` as above, plus any error from reading data or
/// building the result.
pub fn masked_sum<T: Float>(mt: &MaskedTensor<T>) -> FerrotorchResult<Tensor<T>> {
    if mt.data.is_cuda() {
        return if is_f32::<T>() || is_f64::<T>() {
            masked_sum_gpu(mt)
        } else {
            Err(FerrotorchError::NotImplementedOnCuda { op: "masked_sum" })
        };
    }
    let values = mt.data.data_vec()?;
    let mut total = <T as num_traits::Zero>::zero();
    for (value, _) in values.iter().zip(mt.mask.iter()).filter(|(_, keep)| **keep) {
        total += *value;
    }
    Tensor::from_storage(TensorStorage::cpu(vec![total]), vec![], false)
}
/// GPU implementation of [`masked_sum`] for f32/f64 CUDA tensors.
///
/// Steps:
/// 1. Upload the mask as a 0/1 tensor on the data's device.
/// 2. Multiply the (contiguous) data by it with the backend's `mul_*`.
/// 3. Reduce with `sum_*` over `numel` elements.
///
/// The result is a GPU-resident 0-dim tensor.
///
/// NOTE(review): assuming `mul_*` is an element-wise product, masking by
/// multiplication does not neutralize non-finite values in masked slots
/// (IEEE 754: NaN·0 = NaN, ±Inf·0 = NaN). Data masked by `masked_invalid`
/// therefore sums to NaN here, while the CPU path skips those slots. A
/// select/where-style kernel would avoid this — confirm with the backend.
///
/// # Errors
/// `DeviceUnavailable` if no GPU backend is registered, plus any error from
/// the mask upload, `contiguous`, or the backend kernels.
fn masked_sum_gpu<T: Float>(mt: &MaskedTensor<T>) -> FerrotorchResult<Tensor<T>> {
    let device = mt.data.device();
    let weights: Tensor<T> = mask_as_float_tensor(&mt.mask, mt.data.shape(), device)?;
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let len = mt.data.numel();
    let dense = mt.data.contiguous()?;
    let reduced = if is_f32::<T>() {
        let product = backend.mul_f32(dense.gpu_handle()?, weights.gpu_handle()?)?;
        backend.sum_f32(&product, len)?
    } else {
        let product = backend.mul_f64(dense.gpu_handle()?, weights.gpu_handle()?)?;
        backend.sum_f64(&product, len)?
    };
    Tensor::from_storage(TensorStorage::gpu(reduced), vec![], false)
}
/// Converts a boolean mask into a `T` tensor of 1s (valid) and 0s (masked)
/// with the given `shape`.
///
/// The tensor is built on the CPU and moved to `device` when that device is
/// CUDA; otherwise the CPU tensor is returned as-is.
///
/// # Errors
/// Propagates failures from tensor construction or the device transfer.
fn mask_as_float_tensor<T: Float>(
    mask: &[bool],
    shape: &[usize],
    device: crate::device::Device,
) -> FerrotorchResult<Tensor<T>> {
    let (on, off) = (T::from(1.0).unwrap(), <T as num_traits::Zero>::zero());
    let weights: Vec<T> = mask
        .iter()
        .map(|&keep| if keep { on } else { off })
        .collect();
    let host = Tensor::from_storage(TensorStorage::cpu(weights), shape.to_vec(), false)?;
    if !device.is_cuda() {
        return Ok(host);
    }
    host.to(device)
}
/// Whether `T` should be treated as `f32` for GPU dispatch.
///
/// Detection is by byte width only. NOTE(review): this assumes `f32` is the
/// only 4-byte `Float` implementor; confirm against `crate::dtype`.
#[inline]
fn is_f32<T: Float>() -> bool {
    std::mem::size_of::<T>() == std::mem::size_of::<f32>()
}

/// Whether `T` should be treated as `f64` for GPU dispatch.
///
/// Detection is by byte width only. NOTE(review): this assumes `f64` is the
/// only 8-byte `Float` implementor; confirm against `crate::dtype`.
#[inline]
fn is_f64<T: Float>() -> bool {
    std::mem::size_of::<T>() == std::mem::size_of::<f64>()
}
/// Mean of the valid (unmasked) elements of `mt`, as a 0-dim tensor.
///
/// If every element is masked, the result is NaN.
///
/// Dispatch:
/// - CUDA data with an f32/f64 element type goes to the GPU path.
/// - Other element types on CUDA return `NotImplementedOnCuda`.
/// - CPU data is reduced sequentially.
///
/// # Errors
/// `NotImplementedOnCuda` as above, plus any error from reading data or
/// building the result.
pub fn masked_mean<T: Float>(mt: &MaskedTensor<T>) -> FerrotorchResult<Tensor<T>> {
    if mt.data.is_cuda() {
        return if is_f32::<T>() || is_f64::<T>() {
            masked_mean_gpu(mt)
        } else {
            Err(FerrotorchError::NotImplementedOnCuda { op: "masked_mean" })
        };
    }
    let values = mt.data.data_vec()?;
    let mut total = <T as num_traits::Zero>::zero();
    let mut hits: usize = 0;
    for (&value, _) in values.iter().zip(&mt.mask).filter(|&(_, &keep)| keep) {
        total += value;
        hits += 1;
    }
    let mean = match hits {
        0 => T::from(f64::NAN).unwrap(),
        n => total / T::from(n as f64).unwrap(),
    };
    Tensor::from_storage(TensorStorage::cpu(vec![mean]), vec![], false)
}
/// GPU implementation of [`masked_mean`] for f32/f64 CUDA tensors.
///
/// Counts the valid elements on the host. If there are none, the result is
/// NaN. Otherwise it computes the masked sum with [`masked_sum_gpu`], copies
/// that scalar back to the host, and divides by the count there.
///
/// The returned tensor is always a CPU scalar. NOTE(review): this differs
/// from `masked_sum` on CUDA, which returns a GPU tensor; confirm whether
/// callers rely on it. The sum inherits `masked_sum_gpu`'s handling of
/// non-finite values in masked slots.
///
/// # Errors
/// Propagates errors from the GPU sum, the host copy, or tensor construction.
fn masked_mean_gpu<T: Float>(mt: &MaskedTensor<T>) -> FerrotorchResult<Tensor<T>> {
    let valid = mt.count_valid();
    let mean = if valid == 0 {
        T::from(f64::NAN).unwrap()
    } else {
        let total = masked_sum_gpu(mt)?.cpu()?.data()?[0];
        total / T::from(valid as f64).unwrap()
    };
    Tensor::from_storage(TensorStorage::cpu(vec![mean]), vec![], false)
}
/// Smallest valid (unmasked) element of `mt`, as a 0-dim tensor.
///
/// f32/f64 CUDA data uses the GPU kernel; everything else uses the CPU scan.
/// If every element is masked, the result is NaN.
///
/// # Errors
/// Propagates errors from the selected implementation.
pub fn masked_min<T: Float>(mt: &MaskedTensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_native = mt.data.is_cuda() && (is_f32::<T>() || is_f64::<T>());
    if gpu_native {
        masked_extremum_gpu(mt, true)
    } else {
        masked_extremum_cpu(mt, true)
    }
}

/// Largest valid (unmasked) element of `mt`, as a 0-dim tensor.
///
/// f32/f64 CUDA data uses the GPU kernel; everything else uses the CPU scan.
/// If every element is masked, the result is NaN.
///
/// # Errors
/// Propagates errors from the selected implementation.
pub fn masked_max<T: Float>(mt: &MaskedTensor<T>) -> FerrotorchResult<Tensor<T>> {
    let gpu_native = mt.data.is_cuda() && (is_f32::<T>() || is_f64::<T>());
    if gpu_native {
        masked_extremum_gpu(mt, false)
    } else {
        masked_extremum_cpu(mt, false)
    }
}
/// CPU reduction shared by [`masked_min`] and [`masked_max`].
///
/// Scans the valid elements (`mask[i] == true`) in order. It keeps the
/// smallest one when `pick_min` is `true`, otherwise the largest.
///
/// Special cases:
/// - If any valid element is NaN, the result is NaN, wherever it appears.
/// - If every element is masked, the result is NaN.
///
/// When the input lives on a CUDA device, the scalar result is moved back to
/// that device.
///
/// NOTE(review): the NaN semantics of the GPU kernels used by
/// `masked_extremum_gpu` are not visible here; confirm they agree.
///
/// # Errors
/// Propagates failures from reading the data, building the result, or the
/// device transfer.
fn masked_extremum_cpu<T: Float>(
    mt: &MaskedTensor<T>,
    pick_min: bool,
) -> FerrotorchResult<Tensor<T>> {
    let device = mt.data.device();
    let data = mt.data.data_vec()?;
    let mut best: Option<T> = None;
    for (&v, _) in data.iter().zip(mt.mask.iter()).filter(|&(_, &valid)| valid) {
        // A value is unordered with itself only when it is NaN. Propagate it
        // immediately (as torch/numpy reductions do). With plain `<` / `>`,
        // a NaN would survive only if it happened to be the first valid
        // element, making the result depend on element order.
        if v.partial_cmp(&v).is_none() {
            best = Some(v);
            break;
        }
        best = Some(match best {
            None => v,
            Some(b) if pick_min => {
                if v < b {
                    v
                } else {
                    b
                }
            }
            Some(b) => {
                if v > b {
                    v
                } else {
                    b
                }
            }
        });
    }
    let val = best.unwrap_or_else(|| T::from(f64::NAN).unwrap());
    let cpu = Tensor::from_storage(TensorStorage::cpu(vec![val]), vec![], false)?;
    if device.is_cuda() {
        cpu.to(device)
    } else {
        Ok(cpu)
    }
}
/// GPU reduction shared by [`masked_min`] and [`masked_max`] for f32/f64
/// CUDA tensors.
///
/// Uploads the mask as a 0/1 tensor and passes the (contiguous) data and
/// mask to the backend's `masked_min_*` or `masked_max_*` kernel.
///
/// If every element is masked, no kernel runs and the result is NaN. That
/// NaN is placed on the input's device, so the result always lives on the
/// same device as the input: the normal path returns a GPU tensor, and
/// `masked_extremum_cpu` also moves its result back to the input device.
///
/// # Errors
/// `DeviceUnavailable` if no GPU backend is registered, plus any error from
/// the mask upload, `contiguous`, the kernels, or the device transfer.
fn masked_extremum_gpu<T: Float>(
    mt: &MaskedTensor<T>,
    pick_min: bool,
) -> FerrotorchResult<Tensor<T>> {
    let device = mt.data.device();
    if mt.count_valid() == 0 {
        // Previously this returned a CPU tensor while every other exit
        // returned a device tensor; move it to the input's device instead.
        let nan = T::from(f64::NAN).unwrap();
        return Tensor::from_storage(TensorStorage::cpu(vec![nan]), vec![], false)?.to(device);
    }
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let numel = mt.data.numel();
    let mask_t: Tensor<T> = mask_as_float_tensor(&mt.mask, mt.data.shape(), device)?;
    let data = mt.data.contiguous()?;
    let result_h = if pick_min {
        if is_f32::<T>() {
            backend.masked_min_f32(data.gpu_handle()?, mask_t.gpu_handle()?, numel)?
        } else {
            backend.masked_min_f64(data.gpu_handle()?, mask_t.gpu_handle()?, numel)?
        }
    } else if is_f32::<T>() {
        backend.masked_max_f32(data.gpu_handle()?, mask_t.gpu_handle()?, numel)?
    } else {
        backend.masked_max_f64(data.gpu_handle()?, mask_t.gpu_handle()?, numel)?
    };
    Tensor::from_storage(TensorStorage::gpu(result_h), vec![], false)
}
/// Number of valid (unmasked) elements, returned as a 0-dim CPU tensor of `T`.
///
/// The count comes from the host-side mask, so no device data is read.
///
/// # Errors
/// Propagates failures from tensor construction.
pub fn masked_count<T: Float>(mt: &MaskedTensor<T>) -> FerrotorchResult<Tensor<T>> {
    let valid = T::from(mt.count_valid() as f64).unwrap();
    Tensor::from_storage(TensorStorage::cpu(vec![valid]), vec![], false)
}
/// Masks out every element of `data` whose `condition` flag is `true`.
///
/// `condition` marks elements to hide, so the stored validity mask is its
/// element-wise negation.
///
/// # Errors
/// [`FerrotorchError::ShapeMismatch`] if `condition.len() != data.numel()`.
pub fn masked_where<T: Float>(
    data: Tensor<T>,
    condition: &[bool],
) -> FerrotorchResult<MaskedTensor<T>> {
    let numel = data.numel();
    if condition.len() == numel {
        let keep: Vec<bool> = condition.iter().map(|hide| !hide).collect();
        return MaskedTensor::new(data, keep);
    }
    Err(FerrotorchError::ShapeMismatch {
        message: format!(
            "masked_where: condition length {} != data numel {}",
            condition.len(),
            numel
        ),
    })
}
/// Masks out every non-finite element (NaN, +Inf, -Inf) of `data`.
///
/// Dispatch:
/// - f32/f64 CUDA data: the backend's `isfinite_mask` kernel computes the
///   predicate on the device, and the result is copied to the host.
/// - Other CUDA element types return `NotImplementedOnCuda`.
/// - CPU data: each element is checked via its `f64` value.
///
/// # Errors
/// `DeviceUnavailable`, `NotImplementedOnCuda`, or any error from the
/// kernels, the host copy, or `MaskedTensor::new`.
pub fn masked_invalid<T: Float>(data: Tensor<T>) -> FerrotorchResult<MaskedTensor<T>> {
    if data.is_cuda() {
        if !(is_f32::<T>() || is_f64::<T>()) {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "masked_invalid",
            });
        }
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let dense = data.contiguous()?;
        let finite_h = backend.isfinite_mask(dense.gpu_handle()?)?;
        let keep = predicate_mask_gpu(backend, &finite_h, data.numel())?;
        return MaskedTensor::new(data, keep);
    }
    let keep: Vec<bool> = data
        .data_vec()?
        .iter()
        .map(|value| value.to_f64().unwrap().is_finite())
        .collect();
    MaskedTensor::new(data, keep)
}
/// Masks out every element of `data` that compares equal to `value`.
///
/// Because the CPU path uses `!=`, a NaN `value` masks nothing.
///
/// Dispatch:
/// - f32/f64 CUDA data: the backend's `ne_scalar_mask` kernel evaluates
///   `!= value` on the device, with `value` passed as `f64`.
/// - Other CUDA element types return `NotImplementedOnCuda`.
///
/// # Errors
/// `DeviceUnavailable`, `NotImplementedOnCuda`, or `InvalidArgument` if
/// `value` cannot be represented as `f64` on the GPU path. Also propagates
/// kernel, host-copy, and `MaskedTensor::new` errors.
pub fn masked_equal<T: Float + PartialEq>(
    data: Tensor<T>,
    value: T,
) -> FerrotorchResult<MaskedTensor<T>> {
    if !data.is_cuda() {
        let values = data.data_vec()?;
        let keep: Vec<bool> = values.iter().map(|&v| v != value).collect();
        return MaskedTensor::new(data, keep);
    }
    if !(is_f32::<T>() || is_f64::<T>()) {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "masked_equal" });
    }
    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let target = value
        .to_f64()
        .ok_or_else(|| FerrotorchError::InvalidArgument {
            message: "masked_equal: value not representable as f64".into(),
        })?;
    let dense = data.contiguous()?;
    let differs_h = backend.ne_scalar_mask(dense.gpu_handle()?, target)?;
    let keep = predicate_mask_gpu(backend, &differs_h, data.numel())?;
    MaskedTensor::new(data, keep)
}
/// Copies a device-side predicate buffer to the host and decodes it into a
/// validity mask: any non-zero byte means valid.
///
/// At most `numel` entries are returned; extra trailing bytes are ignored.
/// A short buffer yields a short mask, which `MaskedTensor::new` then rejects.
///
/// NOTE(review): assumes the kernel writes one byte per element; confirm
/// against the backend's mask layout.
///
/// # Errors
/// Propagates the device-to-host copy error.
fn predicate_mask_gpu(
    backend: &dyn crate::gpu_dispatch::GpuBackend,
    mask_h: &crate::gpu_dispatch::GpuBufferHandle,
    numel: usize,
) -> FerrotorchResult<Vec<bool>> {
    let host_bytes = backend.gpu_to_cpu(mask_h)?;
    let flags: Vec<bool> = host_bytes
        .iter()
        .copied()
        .take(numel)
        .map(|byte| byte != 0)
        .collect();
    Ok(flags)
}
// Unit tests for MaskedTensor and the masked_* functions.
// `t` builds CPU tensors directly, and `creation::tensor` is presumably CPU as
// well (see `constructors_accept_cpu_tensors`), so these tests exercise the
// CPU code paths and the ferray conversion, not the GPU branches.
#[cfg(test)]
mod tests {
use super::*;
use crate::creation::tensor;
// Builds a CPU f64 tensor from a flat slice and a shape; panics on failure (test-only helper).
fn t(data: &[f64], shape: &[usize]) -> Tensor<f64> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
}
// Absolute-tolerance float comparison.
fn close(a: f64, b: f64, tol: f64) -> bool {
(a - b).abs() < tol
}
#[test]
fn new_with_matching_mask() {
let d = t(&[1.0, 2.0, 3.0], &[3]);
let m = MaskedTensor::new(d, vec![true, false, true]).unwrap();
assert_eq!(m.shape(), &[3]);
assert_eq!(m.numel(), 3);
assert_eq!(m.count_valid(), 2);
assert_eq!(m.count_masked(), 1);
}
// A mask shorter than numel must be rejected with ShapeMismatch.
#[test]
fn new_rejects_mask_length_mismatch() {
let d = t(&[1.0, 2.0, 3.0], &[3]);
let err = MaskedTensor::new(d, vec![true, false]).unwrap_err();
assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
}
#[test]
fn from_data_marks_all_valid() {
let d = t(&[1.0, 2.0, 3.0], &[3]);
let m = MaskedTensor::from_data(d).unwrap();
assert_eq!(m.count_valid(), 3);
assert_eq!(m.count_masked(), 0);
}
// `condition == true` hides an element, so the stored mask is its negation.
#[test]
fn masked_where_inverts_condition() {
let d = t(&[10.0, 20.0, 30.0, 40.0], &[4]);
let mt = masked_where(d, &[false, true, false, true]).unwrap();
assert_eq!(mt.mask(), &[true, false, true, false]);
assert_eq!(mt.count_valid(), 2);
}
// Both NaN and +inf are non-finite and therefore masked out.
#[test]
fn masked_invalid_masks_nan() {
let d = t(&[1.0, f64::NAN, 3.0, f64::INFINITY], &[4]);
let mt = masked_invalid(d).unwrap();
assert_eq!(mt.mask(), &[true, false, true, false]);
}
#[test]
fn masked_equal_masks_matching() {
let d = t(&[1.0, 5.0, 5.0, 2.0], &[4]);
let mt = masked_equal(d, 5.0).unwrap();
assert_eq!(mt.mask(), &[true, false, false, true]);
}
// Valid entries are 1, 3 and 5, so the sum is 9.
#[test]
fn masked_sum_skips_masked_entries() {
let d = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[5]);
let mt = MaskedTensor::new(d, vec![true, false, true, false, true]).unwrap();
let s = masked_sum(&mt).unwrap();
assert!(close(s.data().unwrap()[0], 9.0, 1e-12));
}
// (10 + 30 + 50) / 3 valid entries = 30; masked zeros must not dilute the mean.
#[test]
fn masked_mean_divides_by_valid_count() {
let d = t(&[10.0, 0.0, 30.0, 0.0, 50.0], &[5]);
let mt = MaskedTensor::new(d, vec![true, false, true, false, true]).unwrap();
let r = masked_mean(&mt).unwrap();
assert!(close(r.data().unwrap()[0], 30.0, 1e-12));
}
#[test]
fn masked_mean_all_masked_returns_nan() {
let d = t(&[1.0, 2.0, 3.0], &[3]);
let mt = MaskedTensor::new(d, vec![false, false, false]).unwrap();
let r = masked_mean(&mt).unwrap();
assert!(r.data().unwrap()[0].is_nan());
}
// The global min (1.0) and max (9.0) are masked, so the results come from {5, 2}.
#[test]
fn masked_min_max_skip_masked() {
let d = t(&[5.0, 1.0, 9.0, 2.0], &[4]);
let mt = MaskedTensor::new(d, vec![true, false, false, true]).unwrap();
assert!(close(
masked_min(&mt).unwrap().data().unwrap()[0],
2.0,
1e-12
));
assert!(close(
masked_max(&mt).unwrap().data().unwrap()[0],
5.0,
1e-12
));
}
// Exact float comparison is fine here: small integer counts are exactly representable.
#[test]
#[allow(clippy::float_cmp)]
fn masked_count_returns_valid_count() {
let d = t(&[1.0, 2.0, 3.0, 4.0], &[4]);
let mt = MaskedTensor::new(d, vec![true, false, true, true]).unwrap();
let c = masked_count(&mt).unwrap();
assert_eq!(c.data().unwrap()[0], 3.0);
}
#[test]
fn filled_substitutes_default_zero() {
let d = t(&[1.0, 2.0, 3.0], &[3]);
let mt = MaskedTensor::new(d, vec![true, false, true]).unwrap();
let f = mt.filled().unwrap();
assert_eq!(f.data().unwrap(), &[1.0, 0.0, 3.0]);
}
#[test]
fn filled_uses_fill_value() {
let d = t(&[1.0, 2.0, 3.0], &[3]);
let mt = MaskedTensor::new(d, vec![true, false, true])
.unwrap()
.with_fill_value(-99.0);
let f = mt.filled().unwrap();
assert_eq!(f.data().unwrap(), &[1.0, -99.0, 3.0]);
}
#[test]
fn to_tensor_is_alias_for_filled() {
let d = t(&[1.0, 2.0, 3.0], &[3]);
let mt = MaskedTensor::new(d, vec![true, false, true]).unwrap();
let a = mt.filled().unwrap();
let b = mt.to_tensor().unwrap();
assert_eq!(a.data().unwrap(), b.data().unwrap());
}
// Cross-check: the mean of the valid entries {2, 6} must agree between the
// in-house reduction and ferray_ma, which also confirms the mask inversion in to_ferray.
#[test]
fn to_ferray_round_trip_mean_matches_inhouse() {
let d = t(&[2.0, 4.0, 6.0, 8.0], &[4]);
let mt = MaskedTensor::new(d, vec![true, false, true, false]).unwrap();
let inhouse = masked_mean(&mt).unwrap().data().unwrap()[0];
let ferray_ma_view: MaskedArray<f64, FerrayIxDyn> = mt.to_ferray("test").unwrap();
let ferray_mean = ferray_ma_view.mean().unwrap();
assert!(close(inhouse, ferray_mean, 1e-12));
assert!(close(inhouse, 4.0, 1e-12));
}
#[test]
fn constructors_accept_cpu_tensors() {
let d = tensor(&[1.0_f64, 2.0, 3.0]).unwrap();
assert!(MaskedTensor::from_data(d.clone()).is_ok());
assert!(masked_where(d.clone(), &[false, true, false]).is_ok());
assert!(masked_invalid(d.clone()).is_ok());
assert!(masked_equal(d, 2.0).is_ok());
}
// Negative and largest values (-3, 7) are masked; results come from {1, 5}.
#[test]
fn masked_min_max_match_cpu_definition() {
let d = tensor(&[1.0_f64, -3.0, 5.0, 7.0]).unwrap();
let mt = MaskedTensor::new(d, vec![true, false, true, false]).unwrap();
assert_eq!(masked_min(&mt).unwrap().data().unwrap(), &[1.0]);
assert_eq!(masked_max(&mt).unwrap().data().unwrap(), &[5.0]);
}
#[test]
fn masked_min_max_all_masked_returns_nan() {
let d = tensor(&[1.0_f64, 2.0]).unwrap();
let mt = MaskedTensor::new(d, vec![false, false]).unwrap();
assert!(masked_min(&mt).unwrap().data().unwrap()[0].is_nan());
assert!(masked_max(&mt).unwrap().data().unwrap()[0].is_nan());
}
}