use std::fmt::Display;
use std::fmt::Formatter;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use crate::ArrayRef;
use crate::array::Array;
use crate::array::ArrayParts;
use crate::array::TypedArrayRef;
use crate::array::child_to_validity;
use crate::array::validity_to_child;
use crate::arrays::Masked;
use crate::validity::Validity;
/// Slot index holding the masked array's child array.
pub(super) const CHILD_SLOT: usize = 0;
/// Slot index holding the masked array's validity child; the slot is an `Option` and may be empty.
pub(super) const VALIDITY_SLOT: usize = 1;
/// Total number of slots carried by a masked array.
pub(super) const NUM_SLOTS: usize = 2;
/// Names of the slots, ordered to match `CHILD_SLOT` and `VALIDITY_SLOT`.
pub(super) const SLOT_NAMES: [&str; NUM_SLOTS] = ["child", "validity"];
/// Metadata for a [`Masked`] array.
///
/// This is a unit struct and holds no fields. The child and the validity of a
/// masked array are stored in the array's slots instead.
#[derive(Clone, Debug)]
pub struct MaskedData;

impl Display for MaskedData {
    /// Writes nothing: there is no metadata to render.
    fn fmt(&self, _: &mut Formatter<'_>) -> std::fmt::Result {
        Ok(())
    }
}
/// Slot accessors for [`Masked`] arrays.
///
/// Every `TypedArrayRef<Masked>` gets these methods through the blanket impl
/// that follows the trait.
pub trait MaskedArrayExt: TypedArrayRef<Masked> {
    /// Returns the child array stored in `CHILD_SLOT`.
    ///
    /// # Panics
    ///
    /// Fails through `vortex_expect` if the child slot is empty. The constructor
    /// always populates it.
    fn child(&self) -> &ArrayRef {
        let slots = self.as_ref().slots();
        let maybe_child = slots[CHILD_SLOT].as_ref();
        maybe_child.vortex_expect("validated masked child slot")
    }

    /// Returns the raw validity child stored in `VALIDITY_SLOT`, if there is one.
    fn validity_child(&self) -> Option<&ArrayRef> {
        let slots = self.as_ref().slots();
        slots[VALIDITY_SLOT].as_ref()
    }

    /// Rebuilds the array's [`Validity`] from the validity slot.
    ///
    /// Uses the nullability of the array's own dtype.
    fn masked_validity(&self) -> Validity {
        let array = self.as_ref();
        let nullability = array.dtype().nullability();
        child_to_validity(&array.slots()[VALIDITY_SLOT], nullability)
    }

    /// Materializes [`Self::masked_validity`] as a mask spanning the array's length.
    fn masked_validity_mask(&self) -> vortex_mask::Mask {
        let validity = self.masked_validity();
        validity.to_mask(self.as_ref().len())
    }
}

impl<T: TypedArrayRef<Masked>> MaskedArrayExt for T {}
impl MaskedData {
    /// Validates the inputs of a masked array and returns its (field-less) metadata.
    ///
    /// # Arguments
    ///
    /// * `child_len` - length of the child array.
    /// * `child_all_valid` - whether the child array contains no nulls.
    /// * `validity` - the validity to overlay on the child.
    ///
    /// # Errors
    ///
    /// The checks run in this order, and the first failure is reported:
    /// 1. `validity` is `Validity::NonNullable`.
    /// 2. `child_all_valid` is `false`.
    /// 3. `validity` reports a length that differs from `child_len`.
    pub(crate) fn try_new(
        child_len: usize,
        child_all_valid: bool,
        validity: Validity,
    ) -> VortexResult<Self> {
        if matches!(validity, Validity::NonNullable) {
            vortex_bail!("MaskedArray must have nullable validity, got {validity:?}");
        }
        if !child_all_valid {
            vortex_bail!("MaskedArray children must not have nulls");
        }
        // A validity that reports no length (`maybe_len() == None`) is accepted
        // for any child length.
        match validity.maybe_len() {
            Some(validity_len) if validity_len != child_len => {
                vortex_bail!("Validity must be the same length as a MaskedArray's child");
            }
            _ => Ok(Self),
        }
    }
}
impl Array<Masked> {
    /// Builds a [`Masked`] array that overlays `validity` on `child`.
    ///
    /// The result has the child's length and the child's dtype made nullable.
    /// The child goes in `CHILD_SLOT` and the validity (as a child array, if
    /// any) goes in `VALIDITY_SLOT`.
    ///
    /// # Errors
    ///
    /// Returns an error in either of these cases:
    /// * `child.all_valid()` fails.
    /// * [`MaskedData::try_new`] rejects the inputs: a non-nullable validity, a
    ///   child containing nulls, or a validity whose length differs from the
    ///   child's.
    pub fn try_new(child: ArrayRef, validity: Validity) -> VortexResult<Self> {
        let child_len = child.len();
        let nullable_dtype = child.dtype().as_nullable();
        // Convert to a slot before `validity` is moved into the validator below.
        let validity_child = validity_to_child(&validity, child_len);
        let child_all_valid = child.all_valid()?;
        let metadata = MaskedData::try_new(child_len, child_all_valid, validity)?;

        // The slot order must match CHILD_SLOT (0) and VALIDITY_SLOT (1).
        let parts = ArrayParts::new(Masked, nullable_dtype, child_len, metadata)
            .with_slots(vec![Some(child), validity_child]);

        // SAFETY: `MaskedData::try_new` has just validated the masked-array
        // invariants (nullable validity, a null-free child, matching lengths).
        // The dtype is the child's dtype made nullable, and the slots are laid
        // out as CHILD_SLOT / VALIDITY_SLOT expect.
        // NOTE(review): confirm these cover the full contract of `from_parts_unchecked`.
        Ok(unsafe { Array::from_parts_unchecked(parts) })
    }
}