use vortex_error::VortexResult;
use vortex_error::vortex_err;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::array::ArrayView;
use crate::array::VTable;
use crate::arrays::Bool;
use crate::arrays::scalar_fn::ExactScalarFn;
use crate::arrays::scalar_fn::ScalarFnArrayView;
use crate::kernel::ExecuteParentKernel;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::scalar_fn::fns::mask::Mask as MaskExpr;
/// Encoding-specific handling of a [`MaskExpr`] whose masked input is an array of this
/// encoding, evaluated without an [`ExecutionCtx`].
///
/// Hooked into the optimizer through [`MaskReduceAdaptor`] (an [`ArrayParentReduceRule`]).
/// That adaptor only calls [`MaskReduce::mask`] when the array is child 0 of the mask
/// expression and the mask operand (child 1) is already a [`Bool`] array.
pub trait MaskReduce: VTable {
    /// Tries to compute the result of masking `array` with `mask`.
    ///
    /// `mask` is the mask expression's second child. Return `Ok(None)` to decline; the adaptor
    /// uses that same value when the rule does not apply. `Ok(Some(result))` is returned from
    /// the reduce rule as-is, presumably as the replacement for the parent expression
    /// (NOTE(review): confirm against the optimizer's rule semantics).
    ///
    /// # Errors
    /// Any error returned here is propagated unchanged by [`MaskReduceAdaptor`].
    fn mask(array: ArrayView<'_, Self>, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
}
/// Encoding-specific execution of a [`MaskExpr`] whose masked input is an array of this
/// encoding, with access to an [`ExecutionCtx`].
///
/// Hooked in through [`MaskExecuteAdaptor`] (an [`ExecuteParentKernel`]). That adaptor calls
/// [`MaskKernel::mask`] whenever the array is child 0 of the mask expression. Unlike the
/// reduce path, it forwards the mask operand (child 1) without checking its encoding.
pub trait MaskKernel: VTable {
    /// Tries to compute the result of masking `array` with `mask`.
    ///
    /// `mask` is the mask expression's second child, in whatever encoding it currently has.
    /// `ctx` is the execution context handed to the parent kernel.
    /// Return `Ok(None)` to decline; the adaptor uses that same value when the kernel does not
    /// apply. The result, including `Ok(Some(_))`, is returned from the parent kernel unchanged.
    ///
    /// # Errors
    /// Any error returned here is propagated unchanged by [`MaskExecuteAdaptor`].
    fn mask(
        array: ArrayView<'_, Self>,
        mask: &ArrayRef,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
/// Adapts a [`MaskReduce`] implementation into an [`ArrayParentReduceRule`].
///
/// The rule fires when an array of encoding `V` is the masked input (child 0) of a
/// [`MaskExpr`] scalar-function array.
///
/// The rule only applies when the mask operand (child 1 of the parent) is already a [`Bool`]
/// array. In every other case it declines by returning `Ok(None)`.
#[derive(Default, Debug)]
pub struct MaskReduceAdaptor<V>(pub V);

impl<V> ArrayParentReduceRule<V> for MaskReduceAdaptor<V>
where
    V: MaskReduce,
{
    type Parent = ExactScalarFn<MaskExpr>;

    /// Forwards to [`MaskReduce::mask`] when `array` is the masked input and the mask is `Bool`.
    ///
    /// # Errors
    /// Fails if the parent has no child at index 1, and propagates any error from
    /// [`MaskReduce::mask`].
    fn reduce_parent(
        &self,
        array: ArrayView<'_, V>,
        parent: ScalarFnArrayView<'_, MaskExpr>,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        // Only rewrite when `array` is the masked input, not the mask operand itself.
        if child_idx != 0 {
            return Ok(None);
        }

        // Read the mask operand straight from the parent view. There is no need to clone the
        // whole parent array first just to look up one of its children.
        let mask_child = parent
            .nth_child(1)
            .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;

        // Reduce rules receive no execution context, so only masks that are already `Bool`
        // are handed to the encoding. Anything else is declined here.
        if mask_child.as_opt::<Bool>().is_none() {
            return Ok(None);
        }

        <V as MaskReduce>::mask(array, &mask_child)
    }
}
/// Bridges a [`MaskKernel`] implementation into an [`ExecuteParentKernel`].
///
/// The kernel is consulted whenever an array of encoding `V` is the masked input of a
/// [`MaskExpr`] scalar-function array.
///
/// The mask operand is forwarded as-is: its encoding is not inspected, and the kernel also
/// receives the caller's [`ExecutionCtx`].
#[derive(Default, Debug)]
pub struct MaskExecuteAdaptor<V>(pub V);

impl<V: MaskKernel> ExecuteParentKernel<V> for MaskExecuteAdaptor<V> {
    type Parent = ExactScalarFn<MaskExpr>;

    /// Dispatches to [`MaskKernel::mask`] when `array` sits in slot 0 of the mask expression.
    ///
    /// # Errors
    /// Fails if the parent has no child at index 1, and propagates any error from the kernel.
    fn execute_parent(
        &self,
        array: ArrayView<'_, V>,
        parent: ScalarFnArrayView<'_, MaskExpr>,
        child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        match child_idx {
            // `array` is the masked input; the mask operand lives in slot 1 of the parent.
            0 => {
                let mask_operand = parent
                    .nth_child(1)
                    .ok_or_else(|| vortex_err!("Mask expression must have 2 children"))?;
                <V as MaskKernel>::mask(array, &mask_operand, ctx)
            }
            // Any other slot (e.g. `array` is the mask itself): this kernel does not apply.
            _ => Ok(None),
        }
    }
}