use vortex_error::VortexResult;
use crate::ArrayEq;
use crate::ArrayRef;
use crate::DynArray;
use crate::IntoArray;
use crate::Precision;
use crate::arrays::ConstantArray;
use crate::arrays::ConstantVTable;
use crate::arrays::DictArray;
use crate::arrays::DictVTable;
use crate::arrays::ScalarFnArray;
use crate::arrays::filter::FilterReduceAdaptor;
use crate::arrays::scalar_fn::AnyScalarFn;
use crate::arrays::slice::SliceReduceAdaptor;
use crate::builtins::ArrayBuiltins;
use crate::optimizer::ArrayOptimizer;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::optimizer::rules::ParentRuleSet;
use crate::scalar_fn::fns::cast::Cast;
use crate::scalar_fn::fns::cast::CastReduceAdaptor;
use crate::scalar_fn::fns::like::LikeReduceAdaptor;
use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
use crate::scalar_fn::fns::pack::Pack;
/// Parent-reduce rules registered for [`DictVTable`].
///
/// Each rule looks at a parent of a dictionary array and may return a rewritten, equivalent
/// array (or `None` to decline). The set covers:
/// - filter, cast, mask, `like`, and slice parents, via the generic reduce adaptors;
/// - scalar-function parents whose other inputs are constants, via
///   `DictionaryScalarFnValuesPushDownRule`;
/// - scalar-function parents whose other inputs are dictionaries sharing the same codes, via
///   `DictionaryScalarFnCodesPullUpRule`.
///
/// NOTE(review): presumably rules are tried in the listed order and the first match wins.
/// Confirm against `ParentRuleSet` before reordering entries.
pub(crate) const PARENT_RULES: ParentRuleSet<DictVTable> = ParentRuleSet::new(&[
ParentRuleSet::lift(&FilterReduceAdaptor(DictVTable)),
ParentRuleSet::lift(&CastReduceAdaptor(DictVTable)),
ParentRuleSet::lift(&MaskReduceAdaptor(DictVTable)),
ParentRuleSet::lift(&LikeReduceAdaptor(DictVTable)),
ParentRuleSet::lift(&DictionaryScalarFnValuesPushDownRule),
ParentRuleSet::lift(&DictionaryScalarFnCodesPullUpRule),
ParentRuleSet::lift(&SliceReduceAdaptor(DictVTable)),
]);
/// Pushes a scalar function down onto a dictionary's values.
///
/// Fires when the dictionary is one input of a scalar-function parent and every other input is a
/// constant array. The function is evaluated once per dictionary value, with each constant
/// re-broadcast to the number of values. The result is re-wrapped in a dictionary that reuses
/// the original codes:
///
/// `f(dict(codes, values), c) -> dict(codes, f(values, c))`
#[derive(Debug)]
struct DictionaryScalarFnValuesPushDownRule;

impl ArrayParentReduceRule<DictVTable> for DictionaryScalarFnValuesPushDownRule {
    type Parent = AnyScalarFn;

    /// Returns `Ok(None)` when the rewrite does not apply or is not known to be safe.
    ///
    /// # Errors
    /// Propagates errors from computing code validity, building the new values, or casting.
    fn reduce_parent(
        &self,
        array: &DictArray,
        parent: &ScalarFnArray,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        let sig = parent.scalar_fn().signature();

        // Pack and Cast are deliberately excluded from this rule. Cast is covered by the
        // `CastReduceAdaptor` in the dictionary's parent rules.
        if parent.scalar_fn().is::<Pack>() || parent.scalar_fn().is::<Cast>() {
            return Ok(None);
        }

        // Evaluating over the values is only worthwhile if there are no more values than rows.
        if array.values().len() > array.codes().len() {
            return Ok(None);
        }

        // The function will run on every value, including values no code refers to. A fallible
        // function could then fail on a value the original expression never evaluates.
        if !array.all_values_referenced && sig.is_fallible() {
            tracing::trace!(
                "Not pushing down fallible scalar function {} over dictionary with sparse codes {}",
                parent.scalar_fn(),
                array.encoding_id(),
            );
            return Ok(None);
        }

        // Every other input must be a constant so it can be re-broadcast to the values' length.
        let others_constant = parent
            .children()
            .iter()
            .enumerate()
            .all(|(idx, c)| idx == child_idx || c.is::<ConstantVTable>());
        if !others_constant {
            return Ok(None);
        }

        // A null code produces a null row in the rebuilt dictionary. That is wrong for
        // null-sensitive functions. Check the cheap signature flag first so code validity is
        // only computed (and can only fail) when it actually matters.
        let codes_nullable = array.codes().dtype().is_nullable();
        if sig.is_null_sensitive() && codes_nullable && !array.codes().all_valid()? {
            tracing::trace!(
                "Not pushing down null-sensitive scalar function {} over dictionary with null codes {}",
                parent.scalar_fn(),
                array.encoding_id(),
            );
            return Ok(None);
        }

        let values_len = array.values().len();
        let mut new_children = Vec::with_capacity(parent.children().len());
        for (idx, child) in parent.children().iter().enumerate() {
            if idx == child_idx {
                new_children.push(array.values().clone());
            } else {
                // Checked above: every non-dictionary input is a constant.
                let scalar = child.as_::<ConstantVTable>().scalar().clone();
                new_children.push(ConstantArray::new(scalar, values_len).into_array());
            }
        }

        let new_values =
            ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)?
                .into_array()
                .optimize()?;

        if sig.is_null_sensitive() && codes_nullable {
            // The codes are nullable but, as checked above, contain no nulls. Drop the
            // nullability (presumably so it does not leak into the result dtype), then cast to
            // the parent's exact dtype.
            let new_codes = array.codes().cast(array.codes().dtype().as_nonnullable())?;
            // SAFETY: `new_codes` holds the same code values as `array.codes()`, which are valid
            // indices into `array.values()`. `new_values` was built with that same length
            // (`values_len`).
            let new_dict = unsafe { DictArray::new_unchecked(new_codes, new_values) }.into_array();
            return Ok(Some(new_dict.cast(parent.dtype().clone())?));
        }

        // SAFETY: the codes are unchanged, and `new_values` has `values_len` entries, which is
        // the length the codes were already valid against.
        let new_dict =
            unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array();
        Ok(Some(new_dict))
    }
}
/// Pulls a scalar function over several dictionaries that share codes up onto their values.
///
/// Fires when every input of the scalar-function parent is a dictionary with the same number of
/// values and codes equal to this dictionary's codes:
///
/// `f(dict(codes, a), dict(codes, b)) -> dict(codes, f(a, b))`
///
/// Like the values push-down rule, this evaluates the function over every dictionary value. It
/// therefore applies the same fallibility and null-sensitivity guards.
#[derive(Debug)]
struct DictionaryScalarFnCodesPullUpRule;

impl ArrayParentReduceRule<DictVTable> for DictionaryScalarFnCodesPullUpRule {
    type Parent = AnyScalarFn;

    /// Returns `Ok(None)` when the rewrite does not apply or is not known to be safe.
    ///
    /// # Errors
    /// Propagates errors from computing code validity, building the new values, or casting.
    fn reduce_parent(
        &self,
        array: &DictArray,
        parent: &ScalarFnArray,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        // With a single input there are no siblings to share codes with.
        if parent.children().len() < 2 {
            return Ok(None);
        }

        let sig = parent.scalar_fn().signature();

        // The function will run on every value, including values no code refers to. A fallible
        // function could then fail on a value the original expression never evaluates. All
        // siblings share these codes, so this dictionary's flag speaks for all of them.
        if !array.all_values_referenced && sig.is_fallible() {
            return Ok(None);
        }

        let values_len = array.values().len();

        // Cheap pass first: every sibling must be a dictionary with as many values as this one.
        if !parent.children().iter().enumerate().all(|(idx, c)| {
            idx == child_idx
                || c.as_opt::<DictVTable>()
                    .is_some_and(|c| c.values().len() == values_len)
        }) {
            return Ok(None);
        }

        // More expensive pass: every sibling must have codes equal to this dictionary's codes.
        if !parent.children().iter().enumerate().all(|(idx, c)| {
            idx == child_idx
                || c.as_opt::<DictVTable>()
                    .is_some_and(|c| c.codes().array_eq(array.codes(), Precision::Value))
        }) {
            return Ok(None);
        }

        // A null code produces a null row in the rebuilt dictionary. That is wrong for
        // null-sensitive functions. Check the cheap flag before computing code validity.
        let codes_nullable = array.codes().dtype().is_nullable();
        if sig.is_null_sensitive() && codes_nullable && !array.codes().all_valid()? {
            return Ok(None);
        }

        let mut new_children = Vec::with_capacity(parent.children().len());
        for (idx, child) in parent.children().iter().enumerate() {
            if idx == child_idx {
                new_children.push(array.values().clone());
            } else {
                // Checked above: every sibling is a dictionary.
                new_children.push(child.as_::<DictVTable>().values().clone());
            }
        }

        let new_values =
            ScalarFnArray::try_new(parent.scalar_fn().clone(), new_children, values_len)?
                .into_array()
                .optimize()?;

        if sig.is_null_sensitive() && codes_nullable {
            // The codes are nullable but, as checked above, contain no nulls. Drop the
            // nullability and cast to the parent's exact dtype, as the push-down rule does.
            let new_codes = array.codes().cast(array.codes().dtype().as_nonnullable())?;
            // SAFETY: `new_codes` holds the same code values as `array.codes()`, which are valid
            // indices into `array.values()`. `new_values` was built with that same length
            // (`values_len`).
            let new_dict = unsafe { DictArray::new_unchecked(new_codes, new_values) }.into_array();
            return Ok(Some(new_dict.cast(parent.dtype().clone())?));
        }

        // SAFETY: the codes are unchanged, and `new_values` has `values_len` entries, which is
        // the length the codes were already valid against.
        let new_dict =
            unsafe { DictArray::new_unchecked(array.codes().clone(), new_values) }.into_array();
        Ok(Some(new_dict))
    }
}