use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::Masked;
use crate::arrays::Primitive;
use crate::arrays::PrimitiveArray;
use crate::arrays::slice::SliceReduceAdaptor;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::optimizer::rules::ParentRuleSet;
use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
/// Parent-reduce rules registered for the [`Primitive`] encoding.
///
/// The set is built from three lifted rules: the [`Masked`]-parent rule defined
/// in this file, plus the mask and slice reduce adaptors instantiated for
/// [`Primitive`].
// NOTE(review): whether the order of the entries below determines which rule is
// tried first is decided by `ParentRuleSet`, which is not visible here — confirm
// before reordering.
pub(crate) const RULES: ParentRuleSet<Primitive> = ParentRuleSet::new(&[
ParentRuleSet::lift(&PrimitiveMaskedValidityRule),
ParentRuleSet::lift(&MaskReduceAdaptor(Primitive)),
ParentRuleSet::lift(&SliceReduceAdaptor(Primitive)),
]);
/// Rule that reduces a `Masked` parent wrapping a `Primitive` child.
///
/// The reduction combines the child's validity with the parent's validity and
/// produces a new primitive array that reuses the child's buffer handle and
/// ptype. This is a stateless unit type; all behaviour lives in its
/// `ArrayParentReduceRule` implementation.
#[derive(Debug, Default)]
pub struct PrimitiveMaskedValidityRule;
impl ArrayParentReduceRule<Primitive> for PrimitiveMaskedValidityRule {
    type Parent = Masked;

    /// Rewrites a `Masked` parent over a `Primitive` child into a single
    /// primitive array.
    ///
    /// The child's validity is combined with the parent's validity via
    /// `and`, and the result is attached to a new `PrimitiveArray` that
    /// shares the child's buffer handle and ptype.
    ///
    /// # Returns
    ///
    /// Always `Ok(Some(..))` with the rewritten array when no error occurs;
    /// this rule never declines with `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while reading either validity or while
    /// combining them.
    fn reduce_parent(
        &self,
        array: ArrayView<'_, Primitive>,
        parent: ArrayView<'_, Masked>,
        _child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        // Read the child's validity first, then the parent's, then merge them.
        let child_validity = array.validity()?;
        let parent_validity = parent.validity()?;
        let combined_validity = child_validity.and(parent_validity)?;

        // The data buffer and element type are carried over from the child untouched.
        let handle = array.buffer_handle().clone();
        let ptype = array.ptype();

        // SAFETY: `handle` and `ptype` come unchanged from an existing primitive
        // array; only the validity differs.
        // NOTE(review): this presumes the combined validity still satisfies whatever
        // `new_unchecked_from_handle` requires (e.g. matching length) — its contract
        // is not visible here, confirm against its `# Safety` docs.
        let reduced =
            unsafe { PrimitiveArray::new_unchecked_from_handle(handle, ptype, combined_validity) };

        Ok(Some(reduced.into_array()))
    }
}