use vortex_array::ArrayRef;
use vortex_array::ArrayView;
use vortex_array::IntoArray;
use vortex_array::arrays::Constant;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::ScalarFnArray;
use vortex_array::arrays::scalar_fn::AnyScalarFn;
use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
use vortex_array::arrays::scalar_fn::ScalarFnVTable;
use vortex_array::dtype::DType;
use vortex_array::optimizer::rules::ArrayParentReduceRule;
use vortex_array::optimizer::rules::ParentRuleSet;
use vortex_array::scalar_fn::fns::cast::CastReduceAdaptor;
use vortex_array::scalar_fn::fns::fill_null::FillNullReduceAdaptor;
use vortex_error::VortexResult;
use crate::RunEnd;
use crate::array::RunEndArrayExt;
/// Parent-reduce rules registered for [`RunEnd`] arrays: the `Cast` reduce adaptor,
/// [`RunEndScalarFnRule`] (scalar-function pushdown onto the run values), and the
/// `FillNull` reduce adaptor.
// NOTE(review): entries are presumably tried in the order listed — confirm before reordering.
pub(super) const RULES: ParentRuleSet<RunEnd> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&CastReduceAdaptor(RunEnd)),
    ParentRuleSet::lift(&RunEndScalarFnRule),
    ParentRuleSet::lift(&FillNullReduceAdaptor(RunEnd)),
]);
/// Parent-reduce rule that pushes a scalar function applied to a [`RunEnd`] array down onto
/// that array's run values, keeping the run structure (`ends`, `offset`, `len`) unchanged.
///
/// The rewrite only fires when every other input of the scalar function is a constant and the
/// function's output dtype is `Bool` or `Primitive`.
#[derive(Debug)]
pub(crate) struct RunEndScalarFnRule;

impl ArrayParentReduceRule<RunEnd> for RunEndScalarFnRule {
    type Parent = AnyScalarFn;

    /// Rewrites `parent(.., run_end, ..)` into `RunEnd(ends, parent(.., values, ..))`.
    ///
    /// `child_idx` is the position of `run_end` among `parent`'s children. Each other child
    /// is rebuilt as a [`ConstantArray`] holding the same scalar but sized to the run values.
    /// The scalar function is then evaluated over `values_len` entries rather than over all
    /// `run_end.len()` logical rows.
    ///
    /// Returns `Ok(None)` (no rewrite) when any sibling is not a constant, or when the
    /// parent's dtype is neither `Bool` nor `Primitive`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ScalarFnArray::try_new`] while building the new values.
    fn reduce_parent(
        &self,
        run_end: ArrayView<'_, RunEnd>,
        parent: ArrayView<'_, ScalarFnVTable>,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        // Only constant siblings can be resized to the run-values length without changing
        // their meaning. A non-constant sibling varies per row and blocks the rewrite.
        let siblings_constant = parent
            .iter_children()
            .enumerate()
            .all(|(idx, child)| idx == child_idx || child.is::<Constant>());
        if !siblings_constant {
            return Ok(None);
        }

        // NOTE(review): restricted to bool/primitive outputs, presumably the value types this
        // rewrite is intended for — confirm the reason before widening.
        if !matches!(parent.dtype(), DType::Bool(_) | DType::Primitive(..)) {
            return Ok(None);
        }

        let values = run_end.values();
        let values_len = values.len();

        // Keep the parent's child arity and order. The run-end child becomes its values, and
        // each constant sibling is re-sized to `values_len`.
        let pushed_children: Vec<ArrayRef> = parent
            .children()
            .iter()
            .enumerate()
            .map(|(idx, child)| {
                if idx == child_idx {
                    values.clone()
                } else {
                    let scalar = child.as_::<Constant>().scalar().clone();
                    ConstantArray::new(scalar, values_len).into_array()
                }
            })
            .collect();

        let pushed_values =
            ScalarFnArray::try_new(parent.scalar_fn().clone(), pushed_children, values_len)?
                .into_array();

        // SAFETY: `ends`, `offset` and `len` are copied unchanged from `run_end`, so they keep
        // the run-end invariants that `run_end` already satisfied. `pushed_values` was built
        // with length `values_len`, which is the original values' length, so every run still
        // has exactly one corresponding value.
        // NOTE(review): the values dtype may differ from the original's; this assumes
        // `new_unchecked` does not depend on it — confirm.
        let rebuilt = unsafe {
            RunEnd::new_unchecked(
                run_end.ends().clone(),
                pushed_values,
                run_end.offset(),
                run_end.len(),
            )
        };
        Ok(Some(rebuilt.into_array()))
    }
}