use rstest::rstest;
use vortex_dtype::{DType, Nullability};
use super::*;
use crate::arrays::PrimitiveArray;
use crate::validity::Validity;
use crate::{Array, IntoArray, ToCanonical as _};
/// A `MaskedArray` over a non-nullable `i32` child reports a primitive `i32` dtype with
/// the expected nullability, whether or not its validity actually contains nulls.
#[rstest]
#[case(Validity::AllValid, Nullability::Nullable)]
#[case(Validity::from_iter([true, false, true]), Nullability::Nullable)]
fn test_dtype_nullability(#[case] validity: Validity, #[case] expected: Nullability) {
    let expected_dtype = DType::Primitive(vortex_dtype::PType::I32, expected);

    let values = PrimitiveArray::from_iter([1i32, 2, 3]).into_array();
    let masked = MaskedArray::try_new(values, validity).unwrap();

    assert_eq!(masked.dtype(), &expected_dtype);
}
/// A child whose dtype is nullable but which holds no actual nulls (`Validity::AllValid`)
/// can be wrapped in a `MaskedArray`, and the wrapper reports a nullable `i32` dtype.
#[test]
fn test_dtype_nullability_with_nullable_child() {
    let child =
        PrimitiveArray::new(vortex_buffer::buffer![1i32, 2, 3], Validity::AllValid).into_array();
    // Precondition: the child really does carry a nullable dtype.
    assert!(child.dtype().is_nullable());

    // The behavior under test: wrapping the nullable-typed child must succeed and yield a
    // nullable MaskedArray of the same primitive type and length.
    let array = MaskedArray::try_new(child, Validity::AllValid).unwrap();
    assert_eq!(array.len(), 3);
    assert!(array.dtype().is_nullable());
    assert_eq!(
        array.dtype(),
        &DType::Primitive(vortex_dtype::PType::I32, Nullability::Nullable)
    );
}
/// Canonicalizing a `MaskedArray` must leave its logical dtype unchanged.
#[test]
fn test_canonical_dtype_matches_array_dtype() {
    let array = MaskedArray::try_new(
        PrimitiveArray::from_iter([1i32, 2, 3]).into_array(),
        Validity::AllValid,
    )
    .unwrap();

    let canonical = array.to_canonical();
    let canonical_dtype = canonical.as_ref().dtype();

    assert_eq!(canonical_dtype, array.dtype());
}
/// `masked_child` applies the validity to the child. Positions marked `false` become
/// null, and the result renders the same values as the `MaskedArray` itself.
#[test]
fn test_masked_child_with_validity() {
    let pattern = [true, false, true, false, true];
    let array = MaskedArray::try_new(
        PrimitiveArray::from_iter([1i32, 2, 3, 4, 5]).into_array(),
        Validity::from_iter(pattern),
    )
    .unwrap();

    let masked = array.masked_child().unwrap();
    let prim = masked.to_primitive();

    assert_eq!(prim.valid_count(), 3);
    for (idx, expected_valid) in pattern.into_iter().enumerate() {
        assert_eq!(
            prim.is_valid(idx),
            expected_valid,
            "validity mismatch at index {idx}"
        );
    }

    let array_text = array.as_ref().display_values().to_string();
    let masked_text = masked.display_values().to_string();
    assert_eq!(array_text, masked_text);
}
/// With `AllValid` validity, `masked_child` keeps every element valid and renders the
/// same values as the wrapping `MaskedArray`.
#[test]
fn test_masked_child_all_valid() {
    let wrapper = MaskedArray::try_new(
        PrimitiveArray::from_iter([10i32, 20, 30]).into_array(),
        Validity::AllValid,
    )
    .unwrap();
    let result = wrapper.masked_child().unwrap();

    let len = result.len();
    let valid = result.valid_count();
    assert_eq!(len, 3);
    assert_eq!(valid, 3);

    let wrapper_text = wrapper.as_ref().display_values().to_string();
    let result_text = result.display_values().to_string();
    assert_eq!(wrapper_text, result_text);
}
/// `masked_child` preserves the child's length, and its validity mask equals the supplied
/// validity, for both the constant `AllValid` case and array-backed validities.
#[rstest]
#[case(Validity::AllValid)]
#[case(Validity::from_iter([true, true, true]))]
#[case(Validity::from_iter([false, false, false]))]
#[case(Validity::from_iter([true, false, true, false]))]
fn test_masked_child_preserves_length(#[case] validity: Validity) {
    // Array-backed validities carry their own length. For any other variant (here only
    // `AllValid`), use a fixed 3-element child.
    let len = match &validity {
        Validity::Array(arr) => arr.len(),
        _ => 3,
    };
    // Use a checked conversion instead of a silently truncating `as` cast. An oversized
    // length would then fail loudly rather than build a wrongly sized child.
    let upper = i32::try_from(len).expect("test child length fits in i32");
    let child = PrimitiveArray::from_iter(0..upper).into_array();

    let array = MaskedArray::try_new(child, validity.clone()).unwrap();
    let masked = array.masked_child().unwrap();

    assert_eq!(masked.len(), len);
    assert_eq!(masked.validity_mask(), validity.to_mask(len));
    assert_eq!(
        array.as_ref().display_values().to_string(),
        masked.display_values().to_string()
    );
}