use std::ops::Range;
use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::IntoArray;
use crate::arrays::DecimalArray;
use crate::arrays::DecimalVTable;
use crate::arrays::MaskedArray;
use crate::arrays::MaskedVTable;
use crate::arrays::slice::SliceReduce;
use crate::arrays::slice::SliceReduceAdaptor;
use crate::match_each_decimal_value_type;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::optimizer::rules::ParentRuleSet;
use crate::scalar_fn::fns::mask::MaskReduceAdaptor;
use crate::vtable::ValidityHelper;
/// Parent-reduce rules registered for [`DecimalVTable`], i.e. rewrites the optimizer may apply
/// when a [`DecimalArray`] is the child of another array.
///
/// - [`DecimalMaskedValidityRule`]: folds a [`MaskedArray`] parent into the decimal child by
///   AND-ing the parent's validity into the child's validity.
/// - `MaskReduceAdaptor(DecimalVTable)`: lifts a mask-reduce implementation for
///   `DecimalVTable` into a parent rule (NOTE(review): that implementation is not visible in
///   this file — presumably defined elsewhere; confirm).
/// - `SliceReduceAdaptor(DecimalVTable)`: lifts the [`SliceReduce`] impl for `DecimalVTable`
///   into a parent rule, so a slice over a decimal array is rebuilt as a `DecimalArray`.
///
/// NOTE(review): entries are listed in this order; whether the order decides which rule wins
/// when several match is up to `ParentRuleSet` (not shown here).
pub(crate) static RULES: ParentRuleSet<DecimalVTable> = ParentRuleSet::new(&[
    ParentRuleSet::lift(&DecimalMaskedValidityRule),
    ParentRuleSet::lift(&MaskReduceAdaptor(DecimalVTable)),
    ParentRuleSet::lift(&SliceReduceAdaptor(DecimalVTable)),
]);
/// Pushes a [`MaskedArray`] parent down into its [`DecimalArray`] child.
///
/// This avoids keeping a separate masking layer. The parent's validity is AND-ed into the
/// child's own validity, and a new `DecimalArray` is built over the child's unchanged value
/// buffer and decimal dtype.
#[derive(Default, Debug)]
pub struct DecimalMaskedValidityRule;

impl ArrayParentReduceRule<DecimalVTable> for DecimalMaskedValidityRule {
    type Parent = MaskedVTable;

    /// Rewrites the masked `parent` over `array` as a single `DecimalArray`.
    ///
    /// On success this always returns `Ok(Some(_))`; the rewrite is unconditional.
    ///
    /// # Errors
    ///
    /// Propagates any error returned when combining the child's and the parent's validity
    /// with `and`.
    fn reduce_parent(
        &self,
        array: &DecimalArray,
        parent: &MaskedArray,
        _child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        // The combined validity does not depend on the decimal value type. Compute it once
        // here rather than inside every per-type arm of the dispatch macro below.
        let validity = array.validity().clone().and(parent.validity().clone())?;

        let masked_array = match_each_decimal_value_type!(array.values_type(), |D| {
            let values = array.buffer::<D>();
            // SAFETY: the value buffer and decimal dtype are taken unchanged from `array`,
            // which is already a valid `DecimalArray`; only the validity is replaced.
            // NOTE(review): this relies on the parent's validity having the same length as
            // `array` (expected when `array` is the masked child) — confirm against
            // `MaskedArray`'s invariants.
            unsafe { DecimalArray::new_unchecked(values, array.decimal_dtype(), validity) }
                .into_array()
        });

        Ok(Some(masked_array))
    }
}
impl SliceReduce for DecimalVTable {
    /// Slices a `DecimalArray` to `range`, returning a new `DecimalArray`.
    ///
    /// The value buffer and the validity are both sliced with the same `range`. On success
    /// this always returns `Ok(Some(_))`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while slicing the validity.
    fn slice(array: &Self::Array, range: Range<usize>) -> VortexResult<Option<ArrayRef>> {
        // Validity slicing is independent of the decimal value type. Do it once here instead
        // of duplicating it into every per-type arm of the dispatch macro.
        let validity = array.validity().clone().slice(range.clone())?;

        let result = match_each_decimal_value_type!(array.values_type(), |D| {
            let sliced = array.buffer::<D>().slice(range);
            // SAFETY: the values and the validity were sliced with the same `range`, so their
            // lengths agree (assuming `range` is in bounds for `array`). The values are an
            // unmodified sub-range of an already-valid array, so they still conform to
            // `decimal_dtype()`.
            unsafe { DecimalArray::new_unchecked(sliced, array.decimal_dtype(), validity) }
                .into_array()
        });

        Ok(Some(result))
    }
}