use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::scalar_fn::ExactScalarFn;
use crate::arrays::scalar_fn::ScalarFnArrayView;
use crate::dtype::DType;
use crate::kernel::ExecuteParentKernel;
use crate::matcher::Matcher;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::scalar_fn::fns::cast::Cast;
use crate::vtable::VTable;
/// Cast support for an encoding that can produce the cast result without an
/// [`ExecutionCtx`].
///
/// Wired into the optimizer via [`CastReduceAdaptor`], which handles casts to the array's
/// current dtype itself before calling [`CastReduce::cast`].
pub trait CastReduce: VTable {
    /// Casts `array` to `dtype`.
    ///
    /// NOTE(review): returning `Ok(None)` presumably means "this encoding does not handle
    /// this cast here" so another strategy is tried; confirm against the rule driver.
    fn cast(array: &Self::Array, dtype: &DType) -> VortexResult<Option<ArrayRef>>;
}
/// Cast support for an encoding that needs an [`ExecutionCtx`] to produce the cast result.
///
/// Wired into execution via [`CastExecuteAdaptor`], which handles casts to the array's
/// current dtype itself before calling [`CastKernel::cast`].
pub trait CastKernel: VTable {
    /// Casts `array` to `dtype`, using `ctx` for any execution the cast requires.
    ///
    /// NOTE(review): returning `Ok(None)` presumably means "this encoding does not handle
    /// this cast" so the caller falls back elsewhere; confirm against the kernel driver.
    fn cast(
        array: &Self::Array,
        dtype: &DType,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>>;
}
/// Adapts a [`CastReduce`] implementation for `V` into an [`ArrayParentReduceRule`] whose
/// parent matcher is `ExactScalarFn<Cast>`.
///
/// The wrapped value is the encoding's vtable.
#[derive(Default, Debug)]
pub struct CastReduceAdaptor<V>(pub V);
impl<V> ArrayParentReduceRule<V> for CastReduceAdaptor<V>
where
    V: CastReduce,
{
    type Parent = ExactScalarFn<Cast>;

    /// Tries to rewrite the `Cast` parent of `array` without an execution context.
    ///
    /// The target dtype is read from the `Cast` options. If `array` already has that dtype
    /// the cast is an identity, and `array` itself is returned as an [`ArrayRef`]. In every
    /// other case the outcome of [`CastReduce::cast`] is passed through unchanged.
    /// `_child_idx` is not consulted.
    fn reduce_parent(
        &self,
        array: &V::Array,
        parent: ScalarFnArrayView<'_, Cast>,
        _child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        let target = parent.options;
        if array.dtype() == target {
            // Identity cast: re-wrap the child, no encoding-specific work needed.
            Ok(Some(array.clone().into_array()))
        } else {
            <V as CastReduce>::cast(array, target)
        }
    }
}
/// Adapts a [`CastKernel`] implementation for `V` into an [`ExecuteParentKernel`] whose
/// parent matcher is `ExactScalarFn<Cast>`.
///
/// The wrapped value is the encoding's vtable.
#[derive(Default, Debug)]
pub struct CastExecuteAdaptor<V>(pub V);
impl<V> ExecuteParentKernel<V> for CastExecuteAdaptor<V>
where
    V: CastKernel,
{
    type Parent = ExactScalarFn<Cast>;

    /// Executes the `Cast` parent of `array` using the encoding's [`CastKernel`].
    ///
    /// The target dtype comes from the `Cast` options. An identity cast returns `array`
    /// itself as an [`ArrayRef`] without touching `ctx`. In every other case
    /// [`CastKernel::cast`] decides the result. `_child_idx` is not consulted.
    fn execute_parent(
        &self,
        array: &V::Array,
        parent: <Self::Parent as Matcher>::Match<'_>,
        _child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        let target = parent.options;
        let is_identity = array.dtype() == target;
        if is_identity {
            Ok(Some(array.clone().into_array()))
        } else {
            <V as CastKernel>::cast(array, target, ctx)
        }
    }
}