use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::array::VTable;
use crate::arrays::Constant;
use crate::arrays::ConstantArray;
use crate::arrays::ScalarFnVTable;
use crate::arrays::scalar_fn::ExactScalarFn;
use crate::arrays::scalar_fn::ScalarFnArrayExt;
use crate::arrays::scalar_fn::ScalarFnArrayView;
use crate::builtins::ArrayBuiltins;
use crate::kernel::ExecuteParentKernel;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::scalar::Scalar;
use crate::scalar_fn::fns::fill_null::FillNull as FillNullExpr;
/// Execution-free `fill_null` support for an encoding.
///
/// When wired through `FillNullReduceAdaptor`, this is only consulted after the shared
/// `precondition` fast paths declined. That means the fill value is non-null and the array
/// contains a mix of valid and null elements.
pub trait FillNullReduce: VTable {
    /// Attempt to replace every null in `array` with `fill_value` without an execution context.
    ///
    /// Returns `Ok(None)` when this encoding cannot produce a rewritten array.
    fn fill_null(
        array: ArrayView<'_, Self>,
        fill_value: &Scalar,
    ) -> VortexResult<Option<ArrayRef>>;
}
/// `fill_null` support for an encoding that may use an [`ExecutionCtx`].
///
/// When wired through `FillNullExecuteAdaptor`, this is only consulted after the shared
/// `precondition` fast paths declined. That means the fill value is non-null and the array
/// contains a mix of valid and null elements.
pub trait FillNullKernel: VTable {
    /// Replace every null in `array` with `fill_value`, using `ctx` as needed.
    ///
    /// Returns `Ok(None)` when this encoding does not handle the request.
    fn fill_null(
        array: ArrayView<'_, Self>,
        fill_value: &Scalar,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
/// Encoding-independent fast paths for `fill_null`.
///
/// Yields `Ok(Some(_))` when the answer does not depend on the array's encoding:
/// - the array has no nulls (non-nullable dtype, or every element valid): the array is cast
///   to the fill value's dtype;
/// - every element is null: the result is a constant array of `fill_value` with the same
///   length.
///
/// Yields `Ok(None)` otherwise, i.e. the array mixes valid and null elements.
///
/// # Errors
///
/// Fails if `fill_value` is null, or if the validity queries or the cast fail.
pub(super) fn precondition(
    array: &ArrayRef,
    fill_value: &Scalar,
) -> VortexResult<Option<ArrayRef>> {
    vortex_ensure!(
        !fill_value.is_null(),
        "fill_null requires a non-null fill value"
    );

    // `all_valid` is only queried when the dtype admits nulls at all.
    let has_no_nulls = !array.dtype().is_nullable() || array.all_valid()?;
    if has_no_nulls {
        let target_dtype = fill_value.dtype().clone();
        return array.clone().cast(target_dtype).map(Some);
    }

    // Entirely null: every slot takes the fill value. Built lazily, only on that branch.
    let all_filled = array
        .all_invalid()?
        .then(|| ConstantArray::new(fill_value.clone(), array.len()).into_array());
    Ok(all_filled)
}
/// `fill_null` for a constant array.
///
/// A null constant becomes a constant of `fill_value`. A non-null constant is kept, cast to
/// the fill value's dtype. The length is preserved in both cases.
///
/// # Errors
///
/// Fails if casting the existing scalar to the fill value's dtype fails.
pub(crate) fn fill_null_constant(
    array: ArrayView<Constant>,
    fill_value: &Scalar,
) -> VortexResult<ArrayRef> {
    let len = array.len();
    let current = array.scalar();
    let filled = if current.is_null() {
        fill_value.clone()
    } else {
        current.cast(fill_value.dtype())?
    };
    Ok(ConstantArray::new(filled, len).into_array())
}
/// Adapts an encoding's [`FillNullReduce`] impl into a parent-reduce rule.
///
/// The rule fires for a `fill_null` scalar-function parent whose input (child 0) is an array
/// of encoding `V`.
#[derive(Default, Debug)]
pub struct FillNullReduceAdaptor<V>(pub V);

impl<V> ArrayParentReduceRule<V> for FillNullReduceAdaptor<V>
where
    V: FillNullReduce,
{
    type Parent = ExactScalarFn<FillNullExpr>;

    /// Rewrite `fill_null(array, fill_value)` using the shared fast paths, then the encoding.
    ///
    /// Returns `Ok(None)` (no rewrite) when:
    /// - `array` is not the input operand;
    /// - the fill value is not a constant;
    /// - the encoding itself declines.
    fn reduce_parent(
        &self,
        array: ArrayView<'_, V>,
        parent: ScalarFnArrayView<'_, FillNullExpr>,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        // Only the input operand (child 0) is handled; child 1 is the fill value.
        if child_idx != 0 {
            return Ok(None);
        }
        let scalar_fn_array = parent
            .as_opt::<ScalarFnVTable>()
            .vortex_expect("ExactScalarFn matcher confirmed ScalarFnArray");
        // The matcher does not guarantee that the fill value is a constant. This rule only
        // handles scalar fill values, so decline (leaving the parent unchanged) instead of
        // panicking inside the optimizer.
        let Some(fill_value) = scalar_fn_array.get_child(1).as_constant() else {
            return Ok(None);
        };
        let arr = array.array().clone();
        if let Some(result) = precondition(&arr, &fill_value)? {
            return Ok(Some(result));
        }
        <V as FillNullReduce>::fill_null(array, &fill_value)
    }
}
/// Adapts an encoding's [`FillNullKernel`] impl into an execute-parent kernel.
///
/// The kernel fires for a `fill_null` scalar-function parent whose input (child 0) is an array
/// of encoding `V`.
#[derive(Default, Debug)]
pub struct FillNullExecuteAdaptor<V>(pub V);

impl<V> ExecuteParentKernel<V> for FillNullExecuteAdaptor<V>
where
    V: FillNullKernel,
{
    type Parent = ExactScalarFn<FillNullExpr>;

    /// Execute `fill_null(array, fill_value)` using the shared fast paths, then the encoding.
    ///
    /// Returns `Ok(None)` (not handled here) when:
    /// - `array` is not the input operand;
    /// - the fill value is not a constant;
    /// - the encoding's kernel declines.
    fn execute_parent(
        &self,
        array: ArrayView<'_, V>,
        parent: ScalarFnArrayView<'_, FillNullExpr>,
        child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        // Only the input operand (child 0) is handled; child 1 is the fill value.
        if child_idx != 0 {
            return Ok(None);
        }
        let scalar_fn_array = parent
            .as_opt::<ScalarFnVTable>()
            .vortex_expect("ExactScalarFn matcher confirmed ScalarFnArray");
        // The matcher does not guarantee that the fill value is a constant. This kernel only
        // handles scalar fill values, so report "not handled" instead of panicking.
        let Some(fill_value) = scalar_fn_array.get_child(1).as_constant() else {
            return Ok(None);
        };
        let arr = array.array().clone();
        if let Some(result) = precondition(&arr, &fill_value)? {
            return Ok(Some(result));
        }
        <V as FillNullKernel>::fill_null(array, &fill_value, ctx)
    }
}