use vortex_error::VortexUnwrap;
use vortex_mask::Mask;
use crate::Array;
use crate::arrays::BoolArray;
use crate::compute::mask;
/// Runs the `mask` compute conformance suite against `array`.
///
/// Empty arrays are skipped entirely. The double-mask case is only run for
/// arrays with at least five elements; the remaining cases are always invoked
/// for non-empty input and may apply their own minimum-length checks.
pub fn test_mask_conformance(array: &dyn Array) {
    let n = array.len();
    if n == 0 {
        return;
    }

    test_heterogenous_mask(array);
    test_empty_mask(array);
    test_full_mask(array);
    test_alternating_mask(array);
    test_sparse_mask(array);
    test_single_element_mask(array);
    if n >= 5 {
        test_double_mask(array);
    }
    test_nullable_mask_input(array);
}
/// Masks every position except those with `i % 3 == 1`.
///
/// Checks that the result has the input's length, that masked positions are
/// null, and that unmasked positions equal the original value after
/// `into_nullable`.
fn test_heterogenous_mask(array: &dyn Array) {
    let n = array.len();
    // Non-capturing closure, so it is `Copy` and can be reused below.
    let is_masked = |i: usize| i % 3 != 1;

    let selection = Mask::from_iter((0..n).map(is_masked));
    let result = mask(array, &selection).vortex_unwrap();
    assert_eq!(result.len(), array.len());

    for idx in 0..n {
        if !is_masked(idx) {
            assert_eq!(result.scalar_at(idx), array.scalar_at(idx).into_nullable());
        } else {
            assert!(!result.is_valid(idx));
        }
    }
}
/// Applies a mask with no positions selected.
///
/// Checks that every element is preserved, compared against the original
/// value converted with `into_nullable`.
fn test_empty_mask(array: &dyn Array) {
    let n = array.len();
    let keep_everything = Mask::from_iter(vec![false; n]);
    let result = mask(array, &keep_everything).vortex_unwrap();
    assert_eq!(result.len(), array.len());

    (0..n).for_each(|idx| {
        let actual = result.scalar_at(idx);
        assert_eq!(actual, array.scalar_at(idx).into_nullable());
    });
}
/// Applies a mask that selects every position and checks that the whole result is null.
fn test_full_mask(array: &dyn Array) {
    let n = array.len();
    let select_all = Mask::from_iter(vec![true; n]);
    let result = mask(array, &select_all).vortex_unwrap();
    assert_eq!(result.len(), array.len());

    (0..n).for_each(|idx| assert!(!result.is_valid(idx)));
}
/// Masks every even position.
///
/// Checks that even positions become null while odd positions keep their
/// original value (compared after `into_nullable`).
fn test_alternating_mask(array: &dyn Array) {
    let n = array.len();
    let is_even = |i: usize| i % 2 == 0;

    let alternating = Mask::from_iter((0..n).map(is_even));
    let result = mask(array, &alternating).vortex_unwrap();
    assert_eq!(result.len(), array.len());

    for idx in 0..n {
        if is_even(idx) {
            assert!(!result.is_valid(idx));
            continue;
        }
        assert_eq!(result.scalar_at(idx), array.scalar_at(idx).into_nullable());
    }
}
/// Masks every tenth position (`i % 10 == 0`) and verifies the result.
///
/// The result is checked position by position: masked slots must be null, and
/// every other slot must equal the original value after `into_nullable`. The
/// total number of valid slots is also cross-checked. It must equal the length
/// minus the union of masked positions and positions that were already null.
///
/// Arrays shorter than ten elements are skipped.
fn test_sparse_mask(array: &dyn Array) {
    let len = array.len();
    if len < 10 {
        return;
    }

    let pattern: Vec<bool> = (0..len).map(|i| i % 10 == 0).collect();
    let mask_array = Mask::from_iter(pattern.iter().copied());
    let masked = mask(array, &mask_array).vortex_unwrap();
    assert_eq!(masked.len(), array.len());

    // Per-position check: an aggregate count alone would not catch an
    // implementation that nulls the right number of slots at the wrong indices.
    for (i, &masked_out) in pattern.iter().enumerate() {
        if masked_out {
            assert!(!masked.is_valid(i));
        } else {
            assert_eq!(masked.scalar_at(i), array.scalar_at(i).into_nullable());
        }
    }

    let valid_count = (0..len).filter(|&i| masked.is_valid(i)).count();
    let expected_invalid_count = pattern
        .iter()
        .enumerate()
        .filter(|&(i, &masked_out)| masked_out || !array.is_valid(i))
        .count();
    assert_eq!(valid_count, len - expected_invalid_count);
}
/// Masks only the first position.
///
/// Checks that the result keeps the input's length, that index 0 becomes null,
/// and that every other position keeps its original value (compared after
/// `into_nullable`).
///
/// Empty arrays are skipped, since there is no first element to mask.
fn test_single_element_mask(array: &dyn Array) {
    let len = array.len();
    if len == 0 {
        return;
    }

    let mut pattern = vec![false; len];
    pattern[0] = true;
    let mask_array = Mask::from_iter(pattern);
    let masked = mask(array, &mask_array).vortex_unwrap();
    assert_eq!(masked.len(), array.len());

    assert!(!masked.is_valid(0));
    for i in 1..len {
        assert_eq!(masked.scalar_at(i), array.scalar_at(i).into_nullable());
    }
}
/// Applies two masks in sequence: first `i % 3 == 0`, then `i % 2 == 0` on the
/// already-masked result.
///
/// Checks that both intermediate and final results keep the input's length.
/// A position must be null if either mask selected it; otherwise it must equal
/// the original value after `into_nullable`.
fn test_double_mask(array: &dyn Array) {
    let len = array.len();
    let mask1_pattern: Vec<bool> = (0..len).map(|i| i % 3 == 0).collect();
    let mask2_pattern: Vec<bool> = (0..len).map(|i| i % 2 == 0).collect();
    let mask1 = Mask::from_iter(mask1_pattern.iter().copied());
    let mask2 = Mask::from_iter(mask2_pattern.iter().copied());

    let first_masked = mask(array, &mask1).vortex_unwrap();
    assert_eq!(first_masked.len(), len);
    // Mask the output of the first `mask` call again.
    let double_masked = mask(&first_masked, &mask2).vortex_unwrap();
    assert_eq!(double_masked.len(), len);

    for (i, (&m1, &m2)) in mask1_pattern.iter().zip(&mask2_pattern).enumerate() {
        if m1 || m2 {
            assert!(!double_masked.is_valid(i));
        } else {
            assert_eq!(
                double_masked.scalar_at(i),
                array.scalar_at(i).into_nullable()
            );
        }
    }
}
/// Builds the mask from a nullable boolean array.
///
/// The boolean array has values `i % 2 == 0` and is valid where `i % 3 != 0`.
/// It is converted with `to_mask_fill_null_false`, then applied to `array`.
///
/// Checks that the result keeps the input's length. Only positions that are
/// both `true` and valid in the boolean array may become null. All other
/// positions, including those where the boolean array is null, must equal the
/// original value after `into_nullable`.
///
/// Arrays shorter than three elements are skipped.
fn test_nullable_mask_input(array: &dyn Array) {
    let len = array.len();
    if len < 3 {
        return;
    }

    let bool_values: Vec<bool> = (0..len).map(|i| i % 2 == 0).collect();
    let validity_values: Vec<bool> = (0..len).map(|i| i % 3 != 0).collect();

    let bool_array = BoolArray::from_iter(bool_values.clone());
    let validity = crate::validity::Validity::from_iter(validity_values.clone());
    let nullable_mask = BoolArray::from_bool_buffer(bool_array.boolean_buffer().clone(), validity);
    let mask_array = nullable_mask.to_mask_fill_null_false();

    let masked = mask(array, &mask_array).vortex_unwrap();
    assert_eq!(masked.len(), array.len());

    for (i, (&value, &valid)) in bool_values.iter().zip(&validity_values).enumerate() {
        if value && valid {
            assert!(!masked.is_valid(i));
        } else {
            assert_eq!(masked.scalar_at(i), array.scalar_at(i).into_nullable());
        }
    }
}