use vortex_dtype::match_each_native_ptype;
use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::MaskedArray;
use crate::arrays::MaskedVTable;
use crate::arrays::PrimitiveArray;
use crate::arrays::PrimitiveVTable;
use crate::arrays::SliceReduceAdaptor;
use crate::compute::MaskReduceAdaptor;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::optimizer::rules::ParentRuleSet;
use crate::vtable::ValidityHelper;
/// Parent-reduction rules registered for [`PrimitiveVTable`].
///
/// Each entry lets a `PrimitiveArray` child rewrite a particular kind of parent:
/// - [`PrimitiveMaskedValidityRule`]: folds a `MaskedArray` parent's validity into the
///   primitive child.
/// - [`MaskReduceAdaptor`] over `PrimitiveVTable`: the generic mask-reduce adaptor from
///   `crate::compute`, applied to primitive arrays.
/// - [`SliceReduceAdaptor`] over `PrimitiveVTable`: the generic slice-reduce adaptor from
///   `crate::arrays`, applied to primitive arrays.
///
/// NOTE(review): entries are kept in a fixed order; confirm whether `ParentRuleSet`
/// depends on that order (e.g. first match wins) before reordering or inserting rules.
pub(crate) const RULES: ParentRuleSet<PrimitiveVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&PrimitiveMaskedValidityRule),
ParentRuleSet::lift(&MaskReduceAdaptor(PrimitiveVTable)),
ParentRuleSet::lift(&SliceReduceAdaptor(PrimitiveVTable)),
]);
/// Stateless parent-reduction rule that folds the validity of a `MaskedArray`
/// parent into its `PrimitiveArray` child.
///
/// Being a unit struct, it carries no configuration; the rewrite itself lives in
/// its `ArrayParentReduceRule<PrimitiveVTable>` implementation.
#[derive(Debug, Default)]
pub struct PrimitiveMaskedValidityRule;
impl ArrayParentReduceRule<PrimitiveVTable> for PrimitiveMaskedValidityRule {
type Parent = MaskedVTable;
fn reduce_parent(
&self,
array: &PrimitiveArray,
parent: &MaskedArray,
_child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
let masked_array = match_each_native_ptype!(array.ptype(), |T| {
unsafe {
PrimitiveArray::new_unchecked_from_handle(
array.buffer_handle().clone(),
array.ptype(),
array.validity().clone().and(parent.validity().clone())?,
)
}
.into_array()
});
Ok(Some(masked_array))
}
}