use std::sync::LazyLock;
use arcref::ArcRef;
use arrow_array::BooleanArray;
use vortex_dtype::DType;
use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err};
use vortex_mask::Mask;
use vortex_scalar::Scalar;
use crate::arrays::ConstantArray;
use crate::arrow::{FromArrowArray, IntoArrowArray};
use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, cast};
use crate::vtable::VTable;
use crate::{Array, ArrayRef, IntoArray};
/// Lazily-built `mask` compute function.
///
/// On first access, every `MaskKernelRef` collected through `inventory` is registered on it.
/// `MaskFn::invoke` consults those kernels before the encoding's own `invoke` hook and the
/// generic Arrow fallback.
static MASK_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
    let function = ComputeFn::new("mask".into(), ArcRef::new_ref(&MaskFn));
    inventory::iter::<MaskKernelRef>
        .into_iter()
        .for_each(|registered| {
            function.register_kernel(registered.0.clone());
        });
    function
});
/// Forces initialization of `MASK_FN`, which registers all collected kernels.
///
/// Returns the number of kernels registered on it.
pub(crate) fn warm_up_vtable() -> usize {
    let registered = MASK_FN.kernels();
    registered.len()
}
/// Replaces every value of `array` at a position where `mask` is set with null.
///
/// Positions where `mask` is unset keep their original value. The result has the nullable
/// version of `array`'s dtype.
///
/// # Errors
///
/// Propagates any error raised by the `mask` compute function. For example, `MaskFn::return_len`
/// rejects a mask whose length differs from the array's (assuming the compute framework
/// consults it — TODO confirm). `unwrap_array` presumably errors if the output is not an array.
pub fn mask(array: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
    let args = InvocationArgs {
        inputs: &[array.into(), mask.into()],
        options: &(),
    };
    let output = MASK_FN.invoke(&args)?;
    output.unwrap_array()
}
/// Type-erased mask kernel that can be collected through `inventory`.
///
/// Every collected instance is registered on `MASK_FN` when that static is first initialized.
/// Instances are presumably produced by `MaskKernelAdapter::lift` and submitted from encoding
/// modules via `inventory::submit!` — TODO confirm.
pub struct MaskKernelRef(ArcRef<dyn Kernel>);
// Declares `MaskKernelRef` as an inventory collection so `inventory::iter` can enumerate it.
inventory::collect!(MaskKernelRef);
/// Encoding-specific implementation of the `mask` compute function.
///
/// This trait is implemented on an encoding's [`VTable`] and wrapped in [`MaskKernelAdapter`]
/// for registration. The adapter only calls it when the input array downcasts to this encoding
/// (via `as_opt`).
///
/// Before consulting registered kernels, `MaskFn::invoke` already handles all-true and
/// all-false masks and all-invalid arrays. A kernel reached from there therefore never sees
/// `Mask::AllTrue` or `Mask::AllFalse`.
pub trait MaskKernel: VTable {
    /// Should return `array` with every position where `mask` is set replaced by null.
    ///
    /// This matches the generic Arrow `nullif` fallback in `MaskFn`. The result is expected to
    /// carry the nullable version of the input dtype, which is what `MaskFn::return_dtype`
    /// reports — TODO confirm whether this is enforced.
    fn mask(&self, array: &Self::Array, mask: &Mask) -> VortexResult<ArrayRef>;
}
/// Adapts an encoding's [`MaskKernel`] implementation into a type-erased [`Kernel`].
///
/// The wrapped value is the encoding's `VTable`. Use `lift` to obtain a [`MaskKernelRef`].
#[derive(Debug)]
pub struct MaskKernelAdapter<V: VTable>(pub V);
impl<V: VTable + MaskKernel> MaskKernelAdapter<V> {
pub const fn lift(&'static self) -> MaskKernelRef {
MaskKernelRef(ArcRef::new_ref(self))
}
}
impl<V: VTable + MaskKernel> Kernel for MaskKernelAdapter<V> {
fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
let inputs = MaskArgs::try_from(args)?;
let Some(array) = inputs.array.as_opt::<V>() else {
return Ok(None);
};
Ok(Some(V::mask(&self.0, array, inputs.mask)?.into()))
}
}
struct MaskFn;
impl ComputeFnVTable for MaskFn {
fn invoke(
&self,
args: &InvocationArgs,
kernels: &[ArcRef<dyn Kernel>],
) -> VortexResult<Output> {
let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
if matches!(mask, Mask::AllFalse(_)) {
return Ok(cast(array, &array.dtype().as_nullable())?.into());
}
if matches!(mask, Mask::AllTrue(_)) {
return Ok(ConstantArray::new(
Scalar::null(array.dtype().clone().as_nullable()),
array.len(),
)
.into_array()
.into());
}
if array.all_invalid() {
return Ok(array.to_array().into());
}
for kernel in kernels {
if let Some(output) = kernel.invoke(args)? {
return Ok(output);
}
}
if let Some(output) = array.invoke(&MASK_FN, args)? {
return Ok(output);
}
log::debug!("No mask implementation found for {}", array.encoding_id());
let array_ref = array.to_array().into_arrow_preferred()?;
let mask = BooleanArray::new(mask.to_boolean_buffer(), None);
let masked = arrow_select::nullif::nullif(array_ref.as_ref(), &mask)?;
Ok(ArrayRef::from_arrow(masked.as_ref(), true).into())
}
fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
let MaskArgs { array, .. } = MaskArgs::try_from(args)?;
Ok(array.dtype().as_nullable())
}
fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
let MaskArgs { array, mask } = MaskArgs::try_from(args)?;
if mask.len() != array.len() {
vortex_bail!(
"mask.len() is {}, does not equal array.len() of {}",
mask.len(),
array.len()
);
}
Ok(mask.len())
}
fn is_elementwise(&self) -> bool {
true
}
}
/// Borrowed, typed view of the two inputs of the `mask` compute function.
struct MaskArgs<'a> {
    /// The array whose values are masked (input 0).
    array: &'a dyn Array,
    /// The positions to null out (input 1).
    mask: &'a Mask,
}

impl<'a> TryFrom<&InvocationArgs<'a>> for MaskArgs<'a> {
    type Error = VortexError;

    /// Extracts `(array, mask)` from exactly two inputs.
    ///
    /// Fails if the arity is wrong or if either input has the wrong kind.
    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
        let [array_input, mask_input] = value.inputs else {
            vortex_bail!("Mask function requires 2 arguments");
        };
        let Some(array) = array_input.array() else {
            vortex_bail!("Expected input 0 to be an array");
        };
        let Some(mask) = mask_input.mask() else {
            vortex_bail!("Expected input 1 to be a mask");
        };
        Ok(Self { array, mask })
    }
}