use std::sync::LazyLock;
use arcref::ArcRef;
use arrow_array::BooleanArray;
use vortex_dtype::DType;
use vortex_error::{VortexError, VortexExpect, VortexResult, vortex_bail, vortex_err};
use vortex_mask::Mask;
use vortex_scalar::Scalar;
use crate::arrays::ConstantArray;
use crate::arrow::{FromArrowArray, IntoArrowArray};
use crate::compute::{ComputeFn, ComputeFnVTable, InvocationArgs, Kernel, Output, fill_null};
use crate::vtable::VTable;
use crate::{Array, ArrayRef, Canonical, IntoArray, ToCanonical};
static FILTER_FN: LazyLock<ComputeFn> = LazyLock::new(|| {
let compute = ComputeFn::new("filter".into(), ArcRef::new_ref(&Filter));
for kernel in inventory::iter::<FilterKernelRef> {
compute.register_kernel(kernel.0.clone());
}
compute
});
/// Forces initialization of [`FILTER_FN`] and returns how many kernels it has registered.
pub(crate) fn warm_up_vtable() -> usize {
    let registered_kernels = FILTER_FN.kernels();
    registered_kernels.len()
}
/// Returns a new array with only the elements of `array` whose position in `mask` is `true`.
///
/// The mask is expected to have the same length as `array`.
///
/// # Errors
///
/// Returns an error if the mask length differs from the array length, if a filter kernel
/// fails, or if no filter implementation exists for the array's encoding.
pub fn filter(array: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
    let invocation = InvocationArgs {
        inputs: &[array.into(), mask.into()],
        options: &(),
    };
    let output = FILTER_FN.invoke(&invocation)?;
    output.unwrap_array()
}
/// Vtable for the `filter` compute function. It holds no state and is referenced from
/// [`FILTER_FN`].
struct Filter;

impl ComputeFnVTable for Filter {
    /// Filters the invocation's array by its mask.
    ///
    /// The steps are tried in this order:
    ///
    /// 1. The mask selects nothing: return an empty canonical array of the same dtype.
    /// 2. The mask selects everything: return the input array itself.
    /// 3. The array is entirely null: return a constant null array of length `true_count`.
    /// 4. Use the first registered kernel that accepts the array.
    /// 5. Use the array's own implementation of [`FILTER_FN`].
    /// 6. Exactly one position is selected: return a length-1 constant array of that
    ///    element.
    /// 7. The array is non-canonical: canonicalize it and filter again.
    ///
    /// # Errors
    ///
    /// Returns an error if the arguments are malformed, if a kernel fails, or if a
    /// canonical array has no filter implementation.
    fn invoke(
        &self,
        args: &InvocationArgs,
        kernels: &[ArcRef<dyn Kernel>],
    ) -> VortexResult<Output> {
        let FilterArgs { array, mask } = FilterArgs::try_from(args)?;

        // A length mismatch is a caller bug; `return_len` reports it as a proper error.
        debug_assert_eq!(
            array.len(),
            mask.len(),
            "Tried to filter an array via a mask with the wrong length"
        );

        // Computed once and reused by every count-based fast path below.
        let true_count = mask.true_count();

        // Nothing selected.
        if true_count == 0 {
            return Ok(Canonical::empty(array.dtype()).into_array().into());
        }

        // Everything selected.
        if true_count == mask.len() {
            return Ok(array.to_array().into());
        }

        // An all-null array stays all-null; only its length changes.
        if array.validity_mask().true_count() == 0 {
            return Ok(
                ConstantArray::new(Scalar::null(array.dtype().clone()), true_count)
                    .into_array()
                    .into(),
            );
        }

        for kernel in kernels {
            if let Some(output) = kernel.invoke(args)? {
                return Ok(output);
            }
        }
        if let Some(output) = array.invoke(&FILTER_FN, args)? {
            return Ok(output);
        }

        // A single selected element can be extracted directly with `scalar_at`.
        if true_count == 1 {
            let idx = mask.first().vortex_expect("true_count == 1");
            return Ok(ConstantArray::new(array.scalar_at(idx), 1)
                .into_array()
                .into());
        }

        log::debug!("No filter implementation found for {}", array.encoding_id(),);

        // Retry on the canonical form of the array.
        if !array.is_canonical() {
            let canonical = array.to_canonical().into_array();
            return filter(&canonical, mask).map(Into::into);
        }

        vortex_bail!(
            "No filter implementation found for array {}",
            array.encoding()
        )
    }

    /// Filtering never changes the dtype of the array.
    fn return_dtype(&self, args: &InvocationArgs) -> VortexResult<DType> {
        Ok(FilterArgs::try_from(args)?.array.dtype().clone())
    }

    /// The output length is the number of positions selected by the mask.
    ///
    /// # Errors
    ///
    /// Returns an error when the mask and array lengths differ.
    fn return_len(&self, args: &InvocationArgs) -> VortexResult<usize> {
        let FilterArgs { array, mask } = FilterArgs::try_from(args)?;
        if mask.len() != array.len() {
            vortex_bail!(
                "mask.len() is {}, does not equal array.len() of {}",
                mask.len(),
                array.len()
            );
        }
        Ok(mask.true_count())
    }

    /// Filtering is not elementwise: output position `i` generally does not correspond to
    /// input position `i`.
    fn is_elementwise(&self) -> bool {
        false
    }
}
/// Typed view of the arguments passed to the filter compute function.
struct FilterArgs<'a> {
    /// The array being filtered.
    array: &'a dyn Array,
    /// The selection mask; positions set to `true` are kept.
    mask: &'a Mask,
}

impl<'a> TryFrom<&InvocationArgs<'a>> for FilterArgs<'a> {
    type Error = VortexError;

    /// Extracts the `[array, mask]` inputs.
    ///
    /// Fails if there are not exactly two inputs, or if they are not an array followed by a
    /// mask.
    fn try_from(value: &InvocationArgs<'a>) -> Result<Self, Self::Error> {
        let [array_input, mask_input] = value.inputs else {
            vortex_bail!("Expected 2 inputs, found {}", value.inputs.len());
        };

        let Some(array) = array_input.array() else {
            return Err(vortex_err!("Expected first input to be an array"));
        };
        let Some(mask) = mask_input.mask() else {
            return Err(vortex_err!("Expected second input to be a mask"));
        };

        Ok(Self { array, mask })
    }
}
/// A type-erased filter [`Kernel`], collected through the `inventory` registry.
///
/// Every collected instance is registered as a kernel on [`FILTER_FN`] when that function
/// is first initialized.
pub struct FilterKernelRef(pub ArcRef<dyn Kernel>);
inventory::collect!(FilterKernelRef);
/// Encoding-specific filter implementation for the arrays described by this vtable.
///
/// Wrap an implementation in [`FilterKernelAdapter`] and lift it into a [`FilterKernelRef`]
/// so the filter compute function can dispatch to it.
pub trait FilterKernel: VTable {
    /// Returns a new array containing only the elements of `array` at positions where
    /// `selection_mask` is `true`.
    ///
    /// When called through [`filter`], some cases are handled before any kernel runs:
    /// masks that select nothing, masks that select everything, and arrays that are
    /// entirely null.
    fn filter(&self, array: &Self::Array, selection_mask: &Mask) -> VortexResult<ArrayRef>;
}
#[derive(Debug)]
pub struct FilterKernelAdapter<V: VTable>(pub V);
impl<V: VTable + FilterKernel> FilterKernelAdapter<V> {
pub const fn lift(&'static self) -> FilterKernelRef {
FilterKernelRef(ArcRef::new_ref(self))
}
}
impl<V: VTable + FilterKernel> Kernel for FilterKernelAdapter<V> {
fn invoke(&self, args: &InvocationArgs) -> VortexResult<Option<Output>> {
let inputs = FilterArgs::try_from(args)?;
let Some(array) = inputs.array.as_opt::<V>() else {
return Ok(None);
};
let filtered = V::filter(&self.0, array, inputs.mask)?;
Ok(Some(filtered.into()))
}
}
impl dyn Array + '_ {
    /// Converts a boolean array into a [`Mask`], treating null entries as `false`.
    ///
    /// # Errors
    ///
    /// Returns an error if the array's dtype is not [`DType::Bool`], or if filling nulls
    /// fails.
    pub fn try_to_mask_fill_null_false(&self) -> VortexResult<Mask> {
        let dtype = self.dtype();
        let DType::Bool(_) = dtype else {
            vortex_bail!("mask must be bool array, has dtype {}", dtype);
        };

        let false_fill = Scalar::bool(false, dtype.nullability());
        let filled = fill_null(self, &false_fill)?;
        Ok(filled.to_bool().to_mask_fill_null_false())
    }
}
pub fn arrow_filter_fn(array: &dyn Array, mask: &Mask) -> VortexResult<ArrayRef> {
let values = match &mask {
Mask::Values(values) => values,
Mask::AllTrue(_) | Mask::AllFalse(_) => unreachable!("check in filter invoke"),
};
let array_ref = array.to_array().into_arrow_preferred()?;
let mask_array = BooleanArray::new(values.boolean_buffer().clone(), None);
let filtered = arrow_select::filter::filter(array_ref.as_ref(), &mask_array)?;
Ok(ArrayRef::from_arrow(
filtered.as_ref(),
array.dtype().is_nullable(),
))
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::arrays::PrimitiveArray;

    /// The generic path keeps only the selected elements.
    #[test]
    fn test_filter() {
        let items =
            PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)])
                .into_array();
        let mask = Mask::from_iter([true, false, true, false, true]);
        let filtered = filter(&items, &mask).unwrap();
        assert_eq!(
            filtered.to_primitive().as_slice::<i32>(),
            &[0i32, 1i32, 2i32]
        );
    }

    /// A mask selecting nothing yields an empty array of the same dtype.
    #[test]
    fn test_filter_all_false() {
        let items = PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32)]).into_array();
        let filtered = filter(&items, &Mask::new_false(items.len())).unwrap();
        assert_eq!(filtered.len(), 0);
        assert_eq!(filtered.dtype(), items.dtype());
    }

    /// A mask selecting everything keeps every element.
    #[test]
    fn test_filter_all_true() {
        let items =
            PrimitiveArray::from_option_iter([Some(0i32), Some(1i32), Some(2i32)]).into_array();
        let filtered = filter(&items, &Mask::new_true(items.len())).unwrap();
        assert_eq!(
            filtered.to_primitive().as_slice::<i32>(),
            &[0i32, 1i32, 2i32]
        );
    }

    /// Filtering an all-null array only changes its length.
    #[test]
    fn test_filter_all_null() {
        let items = PrimitiveArray::from_option_iter([None::<i32>, None, None]).into_array();
        let filtered = filter(&items, &Mask::from_iter([true, false, true])).unwrap();
        assert_eq!(filtered.len(), 2);
        assert!(filtered.scalar_at(0).is_null());
        assert!(filtered.scalar_at(1).is_null());
    }
}