use vortex_error::VortexResult;
use crate::ArrayEq;
use crate::ArrayRef;
use crate::IntoArray;
use crate::Precision;
use crate::array::ArrayView;
use crate::arrays::Constant;
use crate::arrays::ConstantArray;
use crate::arrays::Dict;
use crate::arrays::DictArray;
use crate::arrays::ScalarFnArray;
use crate::arrays::ScalarFnVTable;
use crate::arrays::dict::DictArraySlotsExt;
use crate::arrays::filter::FilterReduceAdaptor;
use crate::arrays::scalar_fn::AnyScalarFn;
use crate::arrays::scalar_fn::ScalarFnArrayExt;
use crate::arrays::slice::SliceReduceAdaptor;
use crate::builtins::ArrayBuiltins;
use crate::optimizer::ArrayOptimizer;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::optimizer::rules::ParentRuleSet;
use crate::scalar_fn::fns::cast::Cast;
use crate::scalar_fn::fns::cast::CastReduceAdaptor;
use crate::scalar_fn::fns::like::LikeReduceAdaptor;
use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
use crate::scalar_fn::fns::pack::Pack;
/// Parent-reduce rules registered for [`Dict`] arrays.
///
/// Each entry lets a parent array (filter, cast, mask, like, slice, or a generic scalar
/// function) whose child is a dictionary be rewritten into a cheaper equivalent form.
///
/// NOTE(review): presumably the rule set is tried in registration order, so the dedicated
/// adaptors (e.g. `CastReduceAdaptor`) get a chance before the generic scalar-fn rules —
/// confirm against `ParentRuleSet` before reordering entries.
pub(crate) const PARENT_RULES: ParentRuleSet<Dict> = ParentRuleSet::new(&[
ParentRuleSet::lift(&FilterReduceAdaptor(Dict)),
ParentRuleSet::lift(&CastReduceAdaptor(Dict)),
ParentRuleSet::lift(&MaskReduceAdaptor(Dict)),
ParentRuleSet::lift(&LikeReduceAdaptor(Dict)),
// Generic scalar functions: push the function onto the dictionary values when the other
// inputs are constants.
ParentRuleSet::lift(&DictionaryScalarFnValuesPushDownRule),
// Generic scalar functions: pull the shared codes up when every input is a dictionary with
// identical codes.
ParentRuleSet::lift(&DictionaryScalarFnCodesPullUpRule),
ParentRuleSet::lift(&SliceReduceAdaptor(Dict)),
]);
/// Pushes a scalar function down onto the values of a dictionary.
///
/// When a dictionary is one input of a scalar function and every other input is a constant,
/// `f(dict(codes, values), c)` is rewritten as `dict(codes, f(values, c))`. The function then
/// runs once per dictionary value instead of once per row.
#[derive(Debug)]
struct DictionaryScalarFnValuesPushDownRule;

impl ArrayParentReduceRule<Dict> for DictionaryScalarFnValuesPushDownRule {
    type Parent = AnyScalarFn;

    /// Returns `Ok(Some(rewritten))` when the push-down is valid, otherwise `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Propagates failures from validity checks, building/optimizing the new values array, or
    /// casting the codes / result.
    fn reduce_parent(
        &self,
        array: ArrayView<'_, Dict>,
        parent: ArrayView<'_, ScalarFnVTable>,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        let sig = parent.scalar_fn().signature();

        // Pack is never pushed into dictionary values.
        if parent.scalar_fn().is::<Pack>() {
            return Ok(None);
        }
        // Cast is handled by the dedicated `CastReduceAdaptor` rule instead.
        if parent.scalar_fn().is::<Cast>() {
            return Ok(None);
        }

        // With more values than rows, evaluating over the values would do more work, not less.
        if array.values().len() > array.codes().len() {
            return Ok(None);
        }

        // Values that no code references would also be evaluated; a fallible function could
        // fail on them even though it would succeed on every actual row.
        if !array.all_values_referenced && sig.is_fallible() {
            tracing::trace!(
                "Not pushing down fallible scalar function {} over dictionary with sparse codes {}",
                parent.scalar_fn(),
                Dict::ID,
            );
            return Ok(None);
        }

        // Every other input must be a constant so it can be broadcast to the values' length.
        if !parent
            .iter_children()
            .enumerate()
            .all(|(idx, c)| idx == child_idx || c.is::<Constant>())
        {
            return Ok(None);
        }

        let codes_nullable = array.codes().dtype().is_nullable();
        let null_sensitive = sig.is_null_sensitive();

        // A null-sensitive function sees null rows differently from null codes, so the rewrite
        // is only valid when no code is actually null. The cheap flags are tested first so the
        // validity scan only runs when its answer matters.
        if null_sensitive && codes_nullable && !array.codes().all_valid()? {
            tracing::trace!(
                "Not pushing down null-sensitive scalar function {} over dictionary with null codes {}",
                parent.scalar_fn(),
                Dict::ID,
            );
            return Ok(None);
        }

        // Rebuild the function's inputs at the values' length: the dictionary child becomes its
        // values, and each constant is re-broadcast to `values_len`.
        let values_len = array.values().len();
        let mut new_children = Vec::with_capacity(parent.nchildren());
        for (idx, child) in parent.iter_children().enumerate() {
            if idx == child_idx {
                new_children.push(array.values().clone());
            } else {
                let scalar = child.as_::<Constant>().scalar().clone();
                new_children.push(ConstantArray::new(scalar, values_len).into_array());
            }
        }

        let new_values =
            ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)?
                .into_array()
                .optimize()?;

        if null_sensitive && codes_nullable {
            // The codes were verified all-valid above, so their nullability can be dropped;
            // the result is then cast back to the dtype the parent function would produce.
            let new_codes = array.codes().cast(array.codes().dtype().as_nonnullable())?;
            // SAFETY: `new_codes` holds the same code values as the original dictionary, and
            // `new_values` was built with `values_len` rows (the original values' length), so
            // every code still indexes in bounds.
            let new_dict = unsafe { DictArray::new_unchecked(new_codes, new_values) }.into_array();
            return Ok(Some(new_dict.cast(parent.dtype().clone())?));
        }

        // SAFETY: the codes are unchanged and `new_values` was built with `values_len` rows
        // (the original values' length), so every code still indexes in bounds.
        Ok(Some(
            unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array(),
        ))
    }
}
#[derive(Debug)]
struct DictionaryScalarFnCodesPullUpRule;
impl ArrayParentReduceRule<Dict> for DictionaryScalarFnCodesPullUpRule {
type Parent = AnyScalarFn;
fn reduce_parent(
&self,
array: ArrayView<'_, Dict>,
parent: ArrayView<'_, ScalarFnVTable>,
child_idx: usize,
) -> VortexResult<Option<ArrayRef>> {
if parent.nchildren() < 2 {
return Ok(None);
}
if !parent.iter_children().enumerate().all(|(idx, c)| {
idx == child_idx
|| c.as_opt::<Dict>()
.is_some_and(|c| c.values().len() == array.values().len())
}) {
return Ok(None);
}
if !parent.iter_children().enumerate().all(|(idx, c)| {
idx == child_idx
|| c.as_opt::<Dict>()
.is_some_and(|c| c.codes().array_eq(array.codes(), Precision::Value))
}) {
return Ok(None);
}
let mut new_children = Vec::with_capacity(parent.nchildren());
for (idx, child) in parent.iter_children().enumerate() {
if idx == child_idx {
new_children.push(array.values().clone());
} else {
new_children.push(child.as_::<Dict>().values().clone());
}
}
let new_values = ScalarFnArray::try_new(
parent.scalar_fn().clone(),
new_children,
array.values().len(),
)?
.into_array()
.optimize()?;
let new_dict =
unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array();
Ok(Some(new_dict))
}
}