use ferray_core::dimension::Dimension;
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
use crate::MaskedArray;
/// How [`MaskedArray::put`] resolves indices that fall outside `0..len`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PutMode {
    /// Negative indices count back from the end; anything still outside
    /// `0..len` is an error.
    Raise,
    /// Indices wrap modulo the array length (Euclidean remainder, so
    /// negative indices wrap as well).
    Wrap,
    /// Indices are clamped into range: negatives become `0` and indices that
    /// are too large become the last position.
    Clip,
}
impl PutMode {
pub fn parse(s: &str) -> FerrayResult<Self> {
match s {
"raise" => Ok(Self::Raise),
"wrap" => Ok(Self::Wrap),
"clip" => Ok(Self::Clip),
other => Err(FerrayError::invalid_value(format!(
"put: clipmode '{other}' not understood (expected 'raise', 'wrap', or 'clip')"
))),
}
}
}
/// Turn a user-supplied flat index into a position in `0..len` according to
/// `mode`.
///
/// * `Raise`: negative indices are offset by `len`; the result must then lie
///   in `0..len`.
/// * `Wrap`: Euclidean remainder modulo `len`.
/// * `Clip`: clamped into `0..=len - 1`.
///
/// # Errors
/// Returns an index-out-of-bounds error (reporting the original `idx`) when
/// `Raise` cannot bring the index into range, or when `len == 0` under `Wrap`
/// or `Clip`, since there is no valid position to map to.
fn resolve_index(idx: isize, len: usize, mode: PutMode) -> FerrayResult<usize> {
    let bound = len as isize;
    let out_of_bounds = || FerrayError::index_out_of_bounds(idx, 0, len);
    match mode {
        PutMode::Raise => {
            let shifted = if idx < 0 { idx + bound } else { idx };
            if (0..bound).contains(&shifted) {
                Ok(shifted as usize)
            } else {
                Err(out_of_bounds())
            }
        }
        // Wrapping or clipping into an empty range is impossible.
        PutMode::Wrap | PutMode::Clip if len == 0 => Err(out_of_bounds()),
        PutMode::Wrap => Ok(idx.rem_euclid(bound) as usize),
        // `bound >= 1` here, so the clamp range is non-empty.
        PutMode::Clip => Ok(idx.clamp(0, bound - 1) as usize),
    }
}
impl<T: Element + Copy, D: Dimension> MaskedArray<T, D> {
    /// Write `values` into the flattened array at `indices` and update the
    /// mask of every written position.
    ///
    /// Each index is resolved against `self.size()` according to `mode` (see
    /// [`PutMode`]). All indices are resolved before anything is modified, so
    /// an invalid index leaves the array untouched. An empty `indices` slice
    /// is a no-op.
    ///
    /// Behaviour depends on whether the array has a hard mask backed by a
    /// real mask:
    ///
    /// * **Soft (or no real) mask**: `values` is repeated cyclically to cover
    ///   every index. If `values` is empty the data is left as-is. Each
    ///   target's mask bit becomes the matching `values_mask` bit (also
    ///   repeated cyclically), or `false` when `values_mask` is `None` or
    ///   empty. Written positions therefore become unmasked.
    /// * **Hard mask**: targets that are already masked are skipped entirely
    ///   (neither data nor mask changes). `values` is paired with `indices`
    ///   by position. Surplus values are ignored and missing ones are filled
    ///   with `T::zero()`; there is no cycling. Kept targets take their mask
    ///   bit from `values_mask` as above.
    ///
    /// # Errors
    /// * Any index-resolution error: an out-of-range index under
    ///   [`PutMode::Raise`], or a non-empty `indices` on a zero-sized array.
    /// * An invalid-value error when data must be written but cannot be
    ///   borrowed as a contiguous slice (checked before any mutation).
    /// * Any error returned by `set_mask_flat`.
    pub fn put(
        &mut self,
        indices: &[isize],
        values: &[T],
        values_mask: Option<&[bool]>,
        mode: PutMode,
    ) -> FerrayResult<()> {
        if indices.is_empty() {
            return Ok(());
        }
        let len = self.size();
        let resolved: Vec<usize> = indices
            .iter()
            .map(|&i| resolve_index(i, len, mode))
            .collect::<FerrayResult<_>>()?;
        // Mask bit for the n-th written value; an empty/absent values mask
        // means "unmasked".
        let vmask_bit = |n: usize| -> bool {
            match values_mask {
                Some(vm) if !vm.is_empty() => vm[n % vm.len()],
                _ => false,
            }
        };

        if self.is_hard_mask() && self.has_real_mask() {
            // Decide which targets survive against a snapshot of the mask
            // taken before any writes, so mask bits set during this call do
            // not shield later (duplicate) indices.
            let original_mask: Vec<bool> = self.mask().iter().copied().collect();
            let kept: Vec<(usize, usize)> = resolved
                .iter()
                .copied()
                .enumerate()
                .filter(|&(_, flat)| !original_mask[flat])
                .collect();
            if kept.is_empty() {
                return Ok(());
            }
            let slice = self.data_mut().ok_or_else(|| {
                FerrayError::invalid_value(
                    "put: underlying data is not contiguous; cannot place values",
                )
            })?;
            for &(n, flat) in &kept {
                // Positional pairing: missing values are zero-filled.
                slice[flat] = values.get(n).copied().unwrap_or_else(T::zero);
            }
            for &(n, flat) in &kept {
                self.set_mask_flat(flat, vmask_bit(n))?;
            }
            return Ok(());
        }

        if !values.is_empty() {
            let slice = self.data_mut().ok_or_else(|| {
                FerrayError::invalid_value(
                    "put: underlying data is not contiguous; cannot place values",
                )
            })?;
            // Fewer values than indices: the values repeat cyclically.
            for (&flat, &value) in resolved.iter().zip(values.iter().cycle()) {
                slice[flat] = value;
            }
        }
        for (n, &flat) in resolved.iter().enumerate() {
            self.set_mask_flat(flat, vmask_bit(n))?;
        }
        Ok(())
    }

    /// Assign `values` to every position where `mask` is `true`, updating the
    /// array's mask from `values_mask`.
    ///
    /// `mask`, `values` and `values_mask` must each have length 1 (broadcast
    /// to every position) or exactly `self.size()`. An empty `values_mask` is
    /// treated like `None`.
    ///
    /// Data: `data[i] = values[i]` wherever `mask[i]` is set, regardless of
    /// the current mask, including under a hard mask.
    ///
    /// Mask, for positions where `mask[i]` is set:
    /// * soft mask (or hard mask with no real mask): the bit becomes
    ///   `values_mask[i]`, or `false` when no values mask is given;
    /// * hard mask: positions with `values_mask[i]` set become masked.
    ///   Existing bits are never cleared, and without a values mask the mask
    ///   is left unchanged.
    ///
    /// # Errors
    /// * Shape-mismatch error if `mask` or `values_mask` has a length that
    ///   neither is 1 nor equals `self.size()`.
    /// * Invalid-value error if `values` is empty or has such a length.
    /// * Invalid-value error if the data cannot be borrowed as a contiguous
    ///   slice. This is detected before the mask is touched, so the array is
    ///   left unmodified.
    /// * Any error returned by `set_mask_flat`.
    pub fn putmask(
        &mut self,
        mask: &[bool],
        values: &[T],
        values_mask: Option<&[bool]>,
    ) -> FerrayResult<()> {
        let len = self.size();
        // A length-1 slice broadcasts to every position; otherwise the length
        // must match the array exactly.
        let broadcasts = |n: usize| n == 1 || n == len;
        if !broadcasts(mask.len()) {
            return Err(FerrayError::shape_mismatch(format!(
                "putmask: boolean mask length {} does not broadcast to array size {len}",
                mask.len()
            )));
        }
        if values.is_empty() {
            return Err(FerrayError::invalid_value(
                "putmask: `values` must not be empty",
            ));
        }
        if !broadcasts(values.len()) {
            return Err(FerrayError::invalid_value(format!(
                "putmask: could not broadcast values of length {} into array of size {len}",
                values.len()
            )));
        }
        // An empty values mask means "no values mask" (as in `put`). Any other
        // length must broadcast; without this check the per-element lookups
        // below would index out of bounds and panic.
        let values_mask = values_mask.filter(|vm| !vm.is_empty());
        if let Some(vm) = values_mask {
            if !broadcasts(vm.len()) {
                return Err(FerrayError::shape_mismatch(format!(
                    "putmask: values mask length {} does not broadcast to array size {len}",
                    vm.len()
                )));
            }
        }

        /// Broadcasting accessor: a length-1 slice applies to every position.
        fn pick<U: Copy>(s: &[U], i: usize) -> U {
            if s.len() == 1 {
                s[0]
            } else {
                s[i]
            }
        }
        let mask_bit = |i: usize| pick(mask, i);
        let vmask_bit = |i: usize| match values_mask {
            Some(vm) => pick(vm, i),
            None => false,
        };
        let hard = self.is_hard_mask() && self.has_real_mask();

        // Place the data first. If the storage is not contiguous we bail out
        // before the mask has been modified, leaving the array unchanged.
        let slice = self.data_mut().ok_or_else(|| {
            FerrayError::invalid_value(
                "putmask: underlying data is not contiguous; cannot place values",
            )
        })?;
        for (i, cell) in slice.iter_mut().enumerate() {
            if mask_bit(i) {
                *cell = pick(values, i);
            }
        }

        match (hard, values_mask) {
            // Soft mask: written positions take the values' mask bit
            // (unmasked when there is no values mask).
            (false, _) => {
                for i in (0..len).filter(|&i| mask_bit(i)) {
                    self.set_mask_flat(i, vmask_bit(i))?;
                }
            }
            // Hard mask: bits are only ever added, never cleared.
            (true, Some(_)) => {
                for i in (0..len).filter(|&i| mask_bit(i) && vmask_bit(i)) {
                    self.set_mask_flat(i, true)?;
                }
            }
            // Hard mask without a values mask: nothing to add.
            (true, None) => {}
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::Array;
    use ferray_core::dimension::Ix1;

    /// Build a 1-D `f64` array of length `data.len()`.
    fn arr_f64(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Build a 1-D `bool` array of length `data.len()`.
    fn arr_bool(data: Vec<bool>) -> Array<bool, Ix1> {
        let n = data.len();
        Array::<bool, Ix1>::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Build a 1-D masked array from parallel data and mask vectors.
    fn ma(data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix1> {
        MaskedArray::new(arr_f64(data), arr_bool(mask)).unwrap()
    }

    fn data_of(m: &MaskedArray<f64, Ix1>) -> Vec<f64> {
        m.data().iter().copied().collect()
    }

    fn mask_of(m: &MaskedArray<f64, Ix1>) -> Vec<bool> {
        m.mask().iter().copied().collect()
    }

    #[test]
    fn put_basic_unmasks_written() {
        let mut a = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false, true, false, false]);
        a.put(&[1, 3], &[10.0, 30.0], None, PutMode::Raise).unwrap();
        assert_eq!(data_of(&a), vec![1.0, 10.0, 3.0, 30.0]);
        assert_eq!(mask_of(&a), vec![false, false, false, false]);
    }

    #[test]
    fn put_masked_values_propagate() {
        let mut a = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        a.put(&[0, 1], &[10.0, 20.0], Some(&[true, false]), PutMode::Raise)
            .unwrap();
        assert_eq!(data_of(&a), vec![10.0, 20.0, 3.0, 4.0]);
        assert_eq!(mask_of(&a), vec![true, false, false, false]);
    }

    #[test]
    fn put_values_repeat() {
        let mut a = ma(vec![1.0, 2.0, 3.0, 4.0, 5.0], vec![false; 5]);
        a.put(&[0, 1, 2, 3], &[9.0], None, PutMode::Raise).unwrap();
        assert_eq!(data_of(&a), vec![9.0, 9.0, 9.0, 9.0, 5.0]);
    }

    #[test]
    fn put_mode_wrap() {
        let mut a = ma(vec![1.0, 2.0, 3.0], vec![false; 3]);
        a.put(&[5], &[99.0], None, PutMode::Wrap).unwrap();
        assert_eq!(data_of(&a), vec![1.0, 2.0, 99.0]);
    }

    #[test]
    fn put_mode_clip() {
        let mut a = ma(vec![1.0, 2.0, 3.0], vec![false; 3]);
        a.put(&[5], &[99.0], None, PutMode::Clip).unwrap();
        assert_eq!(data_of(&a), vec![1.0, 2.0, 99.0]);
    }

    #[test]
    fn put_mode_clip_negative() {
        let mut a = ma(vec![1.0, 2.0, 3.0], vec![false; 3]);
        a.put(&[-5], &[9.0], None, PutMode::Clip).unwrap();
        assert_eq!(data_of(&a), vec![9.0, 2.0, 3.0]);
    }

    #[test]
    fn put_mode_wrap_negative() {
        let mut a = ma(vec![1.0, 2.0, 3.0], vec![false; 3]);
        a.put(&[-5], &[9.0], None, PutMode::Wrap).unwrap();
        assert_eq!(data_of(&a), vec![1.0, 9.0, 3.0]);
    }

    #[test]
    fn put_negative_index_raise() {
        let mut a = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        a.put(&[-1, -2], &[10.0, 20.0], None, PutMode::Raise)
            .unwrap();
        assert_eq!(data_of(&a), vec![1.0, 2.0, 20.0, 10.0]);
    }

    #[test]
    fn put_mode_raise_out_of_bounds() {
        let mut a = ma(vec![1.0, 2.0, 3.0], vec![false; 3]);
        assert!(a.put(&[5], &[99.0], None, PutMode::Raise).is_err());
    }

    #[test]
    fn put_mode_raise_negative_out_of_bounds() {
        let mut a = ma(vec![1.0, 2.0, 3.0], vec![false; 3]);
        assert!(a.put(&[-5], &[9.0], None, PutMode::Raise).is_err());
    }

    #[test]
    fn put_hard_mask_suppresses_masked_target() {
        let mut a = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false, true, false, false]);
        a.harden_mask().unwrap();
        a.put(&[1, 3], &[10.0, 30.0], None, PutMode::Raise).unwrap();
        assert_eq!(data_of(&a), vec![1.0, 2.0, 3.0, 30.0]);
        assert_eq!(mask_of(&a), vec![false, true, false, false]);
    }

    #[test]
    fn put_empty_indices_noop() {
        let mut a = ma(vec![1.0, 2.0, 3.0], vec![false; 3]);
        a.put(&[], &[99.0], None, PutMode::Raise).unwrap();
        assert_eq!(data_of(&a), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn put_empty_values_data_noop_unmasks_target() {
        let mut a = ma(vec![1.0, 2.0, 3.0], vec![true, false, false]);
        a.put(&[0], &[], None, PutMode::Raise).unwrap();
        assert_eq!(data_of(&a), vec![1.0, 2.0, 3.0]);
        assert_eq!(mask_of(&a), vec![false, false, false]);
    }

    #[test]
    fn put_hard_mask_short_values_zeropad() {
        let mut a = ma(
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![false, true, false, false, false],
        );
        a.harden_mask().unwrap();
        a.put(&[0, 1, 2], &[10.0, 20.0], None, PutMode::Raise)
            .unwrap();
        assert_eq!(data_of(&a), vec![10.0, 2.0, 0.0, 4.0, 5.0]);
        assert_eq!(mask_of(&a), vec![false, true, false, false, false]);
    }

    #[test]
    fn put_hard_mask_zeropad_all_kept() {
        let mut a = ma(
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![false, true, false, false, false],
        );
        a.harden_mask().unwrap();
        a.put(&[0, 2, 4], &[10.0, 20.0], None, PutMode::Raise)
            .unwrap();
        assert_eq!(data_of(&a), vec![10.0, 2.0, 20.0, 4.0, 0.0]);
    }

    #[test]
    fn put_hard_mask_long_values_truncate() {
        let mut a = ma(
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![false, true, false, false, false],
        );
        a.harden_mask().unwrap();
        a.put(
            &[0, 1, 2],
            &[10.0, 20.0, 30.0, 40.0, 50.0],
            None,
            PutMode::Raise,
        )
        .unwrap();
        assert_eq!(data_of(&a), vec![10.0, 2.0, 30.0, 4.0, 5.0]);
    }

    #[test]
    fn put_values_mask_repeats() {
        // Both the single value and the two-element values mask cycle over
        // the three indices.
        let mut a = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        a.put(&[0, 1, 2], &[5.0], Some(&[true, false]), PutMode::Raise)
            .unwrap();
        assert_eq!(data_of(&a), vec![5.0, 5.0, 5.0, 4.0]);
        assert_eq!(mask_of(&a), vec![true, false, true, false]);
    }

    #[test]
    fn put_empty_values_mask_treated_as_none() {
        let mut a = ma(vec![1.0, 2.0], vec![true, true]);
        a.put(&[0], &[7.0], Some(&[]), PutMode::Raise).unwrap();
        assert_eq!(data_of(&a), vec![7.0, 2.0]);
        assert_eq!(mask_of(&a), vec![false, true]);
    }

    #[test]
    fn putmask_basic_clears_mask() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![true, false, false, false]);
        x.putmask(&[true, false, true, false], &[10.0, 20.0, 30.0, 40.0], None)
            .unwrap();
        assert_eq!(data_of(&x), vec![10.0, 2.0, 30.0, 4.0]);
        assert_eq!(mask_of(&x), vec![false, false, false, false]);
    }

    #[test]
    fn putmask_scalar_broadcast() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        x.putmask(&[true, false, true, false], &[99.0], None)
            .unwrap();
        assert_eq!(data_of(&x), vec![99.0, 2.0, 99.0, 4.0]);
        assert_eq!(mask_of(&x), vec![false, false, false, false]);
    }

    #[test]
    fn putmask_mismatched_values_errors() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        assert!(
            x.putmask(&[true, true, true, true], &[9.0, 8.0], None)
                .is_err()
        );
    }

    #[test]
    fn putmask_mask_length_mismatch_errors() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        assert!(x.putmask(&[true, false], &[9.0], None).is_err());
    }

    #[test]
    fn putmask_length1_mask_broadcasts() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        x.putmask(&[true], &[9.0], None).unwrap();
        assert_eq!(data_of(&x), vec![9.0, 9.0, 9.0, 9.0]);
    }

    #[test]
    fn putmask_length1_false_mask_noop() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        x.putmask(&[false], &[9.0], None).unwrap();
        assert_eq!(data_of(&x), vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn putmask_length1_mask_broadcasts_values_mask() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        x.putmask(&[true], &[9.0], Some(&[true])).unwrap();
        assert_eq!(data_of(&x), vec![9.0, 9.0, 9.0, 9.0]);
        assert_eq!(mask_of(&x), vec![true, true, true, true]);
    }

    #[test]
    fn putmask_hard_keeps_mask_overwrites_data() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![true, false, false, false]);
        x.harden_mask().unwrap();
        x.putmask(&[true, true, false, false], &[10.0, 20.0, 30.0, 40.0], None)
            .unwrap();
        assert_eq!(data_of(&x), vec![10.0, 20.0, 3.0, 4.0]);
        assert_eq!(mask_of(&x), vec![true, false, false, false]);
    }

    #[test]
    fn putmask_soft_masked_values_propagate() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false, false, false, true]);
        x.putmask(
            &[true, true, false, false],
            &[10.0, 20.0, 30.0, 40.0],
            Some(&[true, false, false, false]),
        )
        .unwrap();
        assert_eq!(data_of(&x), vec![10.0, 20.0, 3.0, 4.0]);
        assert_eq!(mask_of(&x), vec![true, false, false, true]);
    }

    #[test]
    fn putmask_hard_masked_values_union() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![false; 4]);
        x.harden_mask().unwrap();
        x.putmask(
            &[true, true, false, false],
            &[10.0, 20.0, 30.0, 40.0],
            Some(&[true, false, false, false]),
        )
        .unwrap();
        assert_eq!(data_of(&x), vec![10.0, 20.0, 3.0, 4.0]);
        assert_eq!(mask_of(&x), vec![true, false, false, false]);
    }

    #[test]
    fn putmask_hard_masked_target_data_overwritten() {
        let mut x = ma(vec![1.0, 2.0, 3.0, 4.0], vec![true, false, false, false]);
        x.harden_mask().unwrap();
        x.putmask(&[true, false, false, false], &[10.0], None)
            .unwrap();
        assert_eq!(data_of(&x), vec![10.0, 2.0, 3.0, 4.0]);
        assert_eq!(mask_of(&x), vec![true, false, false, false]);
    }

    #[test]
    fn put_mode_parse() {
        assert_eq!(PutMode::parse("raise").unwrap(), PutMode::Raise);
        assert_eq!(PutMode::parse("wrap").unwrap(), PutMode::Wrap);
        assert_eq!(PutMode::parse("clip").unwrap(), PutMode::Clip);
        assert!(PutMode::parse("nonsense").is_err());
    }

    #[test]
    fn resolve_index_empty_array_errors_in_every_mode() {
        // There is no valid position in a zero-length array, whatever the mode.
        for mode in [PutMode::Raise, PutMode::Wrap, PutMode::Clip] {
            assert!(resolve_index(0, 0, mode).is_err());
            assert!(resolve_index(-1, 0, mode).is_err());
        }
    }

    #[test]
    fn resolve_index_boundaries() {
        assert_eq!(resolve_index(-3, 3, PutMode::Raise).unwrap(), 0);
        assert_eq!(resolve_index(2, 3, PutMode::Raise).unwrap(), 2);
        assert!(resolve_index(3, 3, PutMode::Raise).is_err());
        assert!(resolve_index(-4, 3, PutMode::Raise).is_err());
        assert_eq!(resolve_index(-4, 3, PutMode::Wrap).unwrap(), 2);
        assert_eq!(resolve_index(3, 3, PutMode::Wrap).unwrap(), 0);
        assert_eq!(resolve_index(isize::MIN, 3, PutMode::Clip).unwrap(), 0);
        assert_eq!(resolve_index(isize::MAX, 3, PutMode::Clip).unwrap(), 2);
    }
}