use vortex_error::VortexResult;
use super::Dict;
use crate::ArrayRef;
use crate::Canonical;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::array::VTable;
use crate::arrays::ConstantArray;
use crate::arrays::dict::DictArraySlotsExt;
use crate::expr::stats::Precision;
use crate::expr::stats::Stat;
use crate::expr::stats::StatsProvider;
use crate::expr::stats::StatsProviderExt;
use crate::kernel::ExecuteParentKernel;
use crate::matcher::Matcher;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::scalar::Scalar;
use crate::stats::StatsSet;
/// Encoding-specific `take` that needs no [`ExecutionCtx`].
///
/// Implementations are driven by [`TakeReduceAdaptor`], which handles the
/// empty-indices and empty-array cases itself before delegating here.
pub trait TakeReduce: VTable {
/// Takes the elements of `array` selected by `indices`.
///
/// An `Ok(None)` result is forwarded unchanged by [`TakeReduceAdaptor`];
/// presumably it signals that this encoding has no reduction to offer for
/// the given inputs (confirm against the optimizer's contract).
fn take(array: ArrayView<'_, Self>, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
}
/// Encoding-specific `take` that is given a mutable [`ExecutionCtx`].
///
/// Implementations are driven by [`TakeExecuteAdaptor`], which handles the
/// empty-indices and empty-array cases itself before delegating here.
pub trait TakeExecute: VTable {
/// Takes the elements of `array` selected by `indices`, using `ctx` as needed.
///
/// An `Ok(None)` result is forwarded unchanged by [`TakeExecuteAdaptor`];
/// presumably it signals that this encoding has no kernel to offer for the
/// given inputs (confirm against the executor's contract).
fn take(
array: ArrayView<'_, Self>,
indices: &ArrayRef,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>>;
}
/// Resolves the degenerate `take` cases that need no encoding-specific work.
///
/// Returns `Some(result)` when either input is empty and `None` when the
/// caller must perform a real take:
///
/// * empty `indices` yields an empty canonical array whose dtype is the
///   array's dtype with its nullability unioned with that of the indices;
/// * an empty `array` (with non-empty `indices`) yields a constant array of
///   `indices.len()` nulls of the nullable form of the array's dtype.
fn precondition<V: VTable>(array: ArrayView<'_, V>, indices: &ArrayRef) -> Option<ArrayRef> {
    if indices.is_empty() {
        // Nothing is selected, so only the output dtype has to be worked out.
        let index_nullability = indices.dtype().nullability();
        let empty_dtype = array.dtype().clone().union_nullability(index_nullability);
        Some(Canonical::empty(&empty_dtype).into_array())
    } else if array.is_empty() {
        // There are no values to select from: every output slot is null.
        let null_scalar = Scalar::null(array.dtype().as_nullable());
        Some(ConstantArray::new(null_scalar, indices.len()).into_array())
    } else {
        None
    }
}
/// Adaptor that exposes a [`TakeReduce`] implementation as an
/// [`ArrayParentReduceRule`] whose parent is [`Dict`].
#[derive(Default, Debug)]
pub struct TakeReduceAdaptor<V>(pub V);
impl<V> ArrayParentReduceRule<V> for TakeReduceAdaptor<V>
where
    V: TakeReduce,
{
    type Parent = Dict;

    /// Rewrites a `Dict` parent into a direct take on `array` by the parent's codes.
    ///
    /// Returns `Ok(None)` when `child_idx` is not `1` or when the encoding's
    /// [`TakeReduce::take`] declines. Results produced by the empty-input
    /// fast path are returned as-is; results produced by the encoding have
    /// statistics propagated from `array` before being returned.
    fn reduce_parent(
        &self,
        array: ArrayView<'_, V>,
        parent: ArrayView<'_, Dict>,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        // Only act when this array sits in slot 1 of the parent
        // (presumably the values slot; confirm against Dict's layout).
        if child_idx != 1 {
            return Ok(None);
        }

        let codes = parent.codes();

        let taken = match precondition::<V>(array, codes) {
            // Empty-input fast path: no stats propagation is performed here.
            Some(shortcut) => return Ok(Some(shortcut)),
            None => <V as TakeReduce>::take(array, codes)?,
        };

        if let Some(output) = taken.as_ref() {
            propagate_take_stats(array.array(), output, codes)?;
        }
        Ok(taken)
    }
}
/// Adaptor that exposes a [`TakeExecute`] implementation as an
/// [`ExecuteParentKernel`] whose parent is [`Dict`].
#[derive(Default, Debug)]
pub struct TakeExecuteAdaptor<V>(pub V);
impl<V> ExecuteParentKernel<V> for TakeExecuteAdaptor<V>
where
    V: TakeExecute,
{
    type Parent = Dict;

    /// Executes a `Dict` parent as a direct take on `array` by the parent's codes.
    ///
    /// Returns `Ok(None)` when `child_idx` is not `1` or when the encoding's
    /// [`TakeExecute::take`] declines. Results produced by the empty-input
    /// fast path are returned as-is; results produced by the encoding have
    /// statistics propagated from `array` before being returned.
    fn execute_parent(
        &self,
        array: ArrayView<'_, V>,
        parent: <Self::Parent as Matcher>::Match<'_>,
        child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        // Only act when this array sits in slot 1 of the parent
        // (presumably the values slot; confirm against Dict's layout).
        if child_idx != 1 {
            return Ok(None);
        }

        let codes = parent.codes();

        let taken = match precondition::<V>(array, codes) {
            // Empty-input fast path: no stats propagation is performed here.
            Some(shortcut) => return Ok(Some(shortcut)),
            None => <V as TakeExecute>::take(array, codes, ctx)?,
        };

        if let Some(output) = taken.as_ref() {
            propagate_take_stats(array.array(), output, codes)?;
        }
        Ok(taken)
    }
}
/// Carries over the statistics of `source` that remain usable on `target`,
/// the result of taking from `source` by `indices`.
///
/// * `IsConstant` is set to exact `true` on `target` only when `indices`
///   reports all-valid and `source` holds an exact `IsConstant = true`.
///   A failing validity check is treated as "not all valid".
/// * `Min` and `Max`, when present on `source`, are downgraded to inexact
///   values and combined into `target`'s statistics.
///
/// # Errors
///
/// Returns whatever error the stats-set mutation or `combine_sets` reports.
pub(crate) fn propagate_take_stats(
    source: &ArrayRef,
    target: &ArrayRef,
    indices: &ArrayRef,
) -> VortexResult<()> {
    target.statistics().with_mut_typed_stats_set(|mut st| {
        // Selecting from a constant array with only valid indices stays constant.
        if indices.all_valid().unwrap_or(false)
            && source.statistics().get_as::<bool>(Stat::IsConstant)
                == Some(Precision::Exact(true))
        {
            st.set(Stat::IsConstant, Precision::exact(true));
        }

        // The taken values may be a subset of the source, so the source's
        // bounds are only carried over as inexact values.
        let mut bounds = Vec::new();
        for stat in [Stat::Min, Stat::Max] {
            let inexact = source
                .statistics()
                .get(stat)
                .and_then(|v| v.map(|s| s.into_value()).into_inexact().transpose());
            if let Some(value) = inexact {
                bounds.push((stat, value));
            }
        }

        // SAFETY: `bounds` holds at most one entry each for `Min` and `Max`,
        // both read from `source`'s own statistics and typed against
        // `source.dtype()` below. NOTE(review): confirm this matches the
        // contract of `StatsSet::new_unchecked`.
        let bounds_set = unsafe { StatsSet::new_unchecked(bounds) };
        st.combine_sets(&bounds_set.as_typed_ref(source.dtype()))
    })
}