use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::arrays::ScalarFnVTable;
use crate::arrays::scalar_fn::ExactScalarFn;
use crate::arrays::scalar_fn::ScalarFnArrayView;
use crate::kernel::ExecuteParentKernel;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::scalar_fn::fns::zip::Zip as ZipExpr;
use crate::vtable::VTable;
/// Encoding-specific rewrite of a `Zip` scalar function that needs no execution context.
///
/// Wired into the optimizer through [`ZipReduceAdaptor`], which only calls
/// [`ZipReduce::zip`] when an array of this encoding is the first child (index 0)
/// of the `Zip` expression.
pub trait ZipReduce: VTable {
    /// Attempts to rewrite the zip of `array` with `if_false`, selected by `mask`.
    ///
    /// * `array` – the zip's first child, an array of this encoding.
    /// * `if_false` – the zip's second child.
    /// * `mask` – the zip's third child.
    ///
    /// Returning `Ok(None)` offers no rewrite. The adaptor likewise answers `Ok(None)`
    /// whenever the rule does not apply.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while constructing the replacement array.
    fn zip(
        array: &Self::Array,
        if_false: &ArrayRef,
        mask: &ArrayRef,
    ) -> VortexResult<Option<ArrayRef>>;
}
/// Encoding-specific execution of a `Zip` scalar function, given an [`ExecutionCtx`].
///
/// Wired into execution through [`ZipExecuteAdaptor`], which only calls
/// [`ZipKernel::zip`] when an array of this encoding is the first child (index 0)
/// of the `Zip` expression.
pub trait ZipKernel: VTable {
    /// Attempts to execute the zip of `array` with `if_false`, selected by `mask`.
    ///
    /// * `array` – the zip's first child, an array of this encoding.
    /// * `if_false` – the zip's second child.
    /// * `mask` – the zip's third child.
    /// * `ctx` – the execution context forwarded from the parent kernel.
    ///
    /// Returning `Ok(None)` offers no specialized result. The adaptor likewise answers
    /// `Ok(None)` whenever the kernel does not apply.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while computing the result.
    fn zip(
        array: &Self::Array,
        if_false: &ArrayRef,
        mask: &ArrayRef,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
/// Bridges a [`ZipReduce`] implementation into an [`ArrayParentReduceRule`] whose
/// parent matcher is an exact `Zip` scalar function.
#[derive(Default, Debug)]
pub struct ZipReduceAdaptor<V>(pub V);

impl<V> ArrayParentReduceRule<V> for ZipReduceAdaptor<V>
where
    V: ZipReduce,
{
    type Parent = ExactScalarFn<ZipExpr>;

    /// Forwards to [`ZipReduce::zip`] when `array` is the zip's first child.
    ///
    /// For any other `child_idx`, answers `Ok(None)` without inspecting the parent.
    fn reduce_parent(
        &self,
        array: &V::Array,
        parent: ScalarFnArrayView<'_, ZipExpr>,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        if child_idx == 0 {
            let zip_array = parent
                .as_opt::<ScalarFnVTable>()
                .vortex_expect("ExactScalarFn matcher confirmed ScalarFnArray");
            let operands = zip_array.children();
            // Slot 0 is `array` itself; slots 1 and 2 carry the false-branch and the mask.
            let (if_false, mask) = (&operands[1], &operands[2]);
            // Fully qualified so the call stays unambiguous if `V` also implements `ZipKernel`.
            <V as ZipReduce>::zip(array, if_false, mask)
        } else {
            Ok(None)
        }
    }
}
/// Bridges a [`ZipKernel`] implementation into an [`ExecuteParentKernel`] whose
/// parent matcher is an exact `Zip` scalar function.
#[derive(Default, Debug)]
pub struct ZipExecuteAdaptor<V>(pub V);

impl<V> ExecuteParentKernel<V> for ZipExecuteAdaptor<V>
where
    V: ZipKernel,
{
    type Parent = ExactScalarFn<ZipExpr>;

    /// Forwards to [`ZipKernel::zip`] when `array` is the zip's first child.
    ///
    /// For any other `child_idx`, answers `Ok(None)` without inspecting the parent.
    fn execute_parent(
        &self,
        array: &V::Array,
        parent: ScalarFnArrayView<'_, ZipExpr>,
        child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        if child_idx == 0 {
            let zip_array = parent
                .as_opt::<ScalarFnVTable>()
                .vortex_expect("ExactScalarFn matcher confirmed ScalarFnArray");
            let operands = zip_array.children();
            // Slot 0 is `array` itself; slots 1 and 2 carry the false-branch and the mask.
            let (if_false, mask) = (&operands[1], &operands[2]);
            // Fully qualified so the call stays unambiguous if `V` also implements `ZipReduce`.
            <V as ZipKernel>::zip(array, if_false, mask, ctx)
        } else {
            Ok(None)
        }
    }
}