use arrow_array::BooleanArray;
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, VortexError, VortexExpect, VortexResult};
use vortex_mask::Mask;
use crate::array::ConstantArray;
use crate::arrow::{FromArrowArray, IntoArrowArray};
use crate::compute::scalar_at;
use crate::encoding::Encoding;
use crate::stats::Stat;
use crate::{Array, Canonical, IntoArray, IntoArrayVariant};
/// Encoding-level hook for filtering an array of type `A` with a [`Mask`].
///
/// The dispatcher in this file (`filter_impl`) returns early for `Mask::AllTrue` and
/// `Mask::AllFalse` before it consults an encoding's filter function, so when called
/// through that path an implementation only sees `Mask::Values`.
pub trait FilterFn<A> {
    /// Returns the elements of `array` selected by `mask`.
    ///
    /// The dispatcher debug-asserts that the returned array's length equals
    /// `mask.true_count()`.
    fn filter(&self, array: &A, mask: &Mask) -> VortexResult<Array>;
}
/// Blanket adapter: any encoding that can filter its own typed array (`E::Array`) can also
/// be asked to filter an untyped [`Array`], by downcasting first.
impl<E: Encoding> FilterFn<Array> for E
where
    E: FilterFn<E::Array>,
    for<'a> &'a E::Array: TryFrom<&'a Array, Error = VortexError>,
{
    /// Downcasts `array` to encoding `E` and delegates to the typed `FilterFn<E::Array>`
    /// implementation.
    ///
    /// # Errors
    ///
    /// Propagates the error from `try_downcast_ref` and any error from the typed filter.
    fn filter(&self, array: &Array, mask: &Mask) -> VortexResult<Array> {
        // The downcast hands back both the typed array and its encoding; the typed filter
        // is invoked on that encoding rather than on `self`.
        let (typed_array, typed_encoding) = array.try_downcast_ref::<E>()?;
        // Fully-qualified path so it is explicit that this is the typed impl, not a
        // recursive call into this blanket impl.
        <E as FilterFn<E::Array>>::filter(typed_encoding, typed_array, mask)
    }
}
/// Returns a new array containing only the elements of `array` selected by `mask`.
///
/// Two trivial cases are answered directly: a mask with a true count of zero yields
/// `Canonical::empty` for the array's dtype, and a mask whose true count equals its length
/// yields a clone of `array`. Every other mask is forwarded to `filter_impl`.
///
/// # Errors
///
/// Returns an error if `mask.len()` differs from `array.len()`, or if `filter_impl` fails.
///
/// # Panics
///
/// With debug assertions enabled, panics if the filtered array's length differs from the
/// mask's true count, or if its dtype differs from the dtype of `array`.
pub fn filter(array: &Array, mask: &Mask) -> VortexResult<Array> {
    let (mask_len, array_len) = (mask.len(), array.len());
    if mask_len != array_len {
        vortex_bail!(
            "mask.len() is {}, does not equal array.len() of {}",
            mask_len,
            array_len
        );
    }

    let selected = mask.true_count();
    match selected {
        // Nothing selected: no need to touch the data at all.
        0 => return Ok(Canonical::empty(array.dtype()).into()),
        // Everything selected: the input is already the answer.
        n if n == mask_len => return Ok(array.clone()),
        _ => {}
    }

    let filtered = filter_impl(array, mask)?;

    // Sanity-check the implementation that produced the result (debug builds only).
    debug_assert_eq!(
        filtered.len(),
        selected,
        "Filter length mismatch {}",
        array.encoding()
    );
    debug_assert_eq!(
        filtered.dtype(),
        array.dtype(),
        "Filter dtype mismatch {}",
        array.encoding()
    );

    Ok(filtered)
}
/// Filters `array` by `mask`, trying strategies in order of preference:
///
/// 1. `Mask::AllTrue` returns a clone of `array`; `Mask::AllFalse` returns
///    `Canonical::empty` for the array's dtype.
/// 2. If the array's vtable provides a `filter_fn`, it is used.
/// 3. If exactly one element is selected and the vtable provides a `scalar_at_fn`, the
///    selected scalar is wrapped in a length-1 `ConstantArray`.
/// 4. Otherwise the array is converted to Arrow, filtered with
///    `arrow_select::filter::filter`, and converted back.
///
/// # Errors
///
/// Propagates errors from the encoding's filter function, from `scalar_at`, from the Arrow
/// conversion, and from the Arrow filter kernel.
fn filter_impl(array: &Array, mask: &Mask) -> VortexResult<Array> {
    // Only a mixed mask carries a buffer of values; the two uniform variants are resolved
    // immediately.
    let mask_values = match mask {
        Mask::Values(mask_values) => mask_values,
        Mask::AllTrue(_) => return Ok(array.clone()),
        Mask::AllFalse(_) => return Ok(Canonical::empty(array.dtype()).into_array()),
    };

    // Preferred path: the encoding's own filter implementation.
    let vtable = array.vtable();
    if let Some(encoding_filter) = vtable.filter_fn() {
        let result = encoding_filter.filter(array, mask)?;
        debug_assert_eq!(result.len(), mask.true_count());
        return Ok(result);
    }

    // A single selected element can be fetched as a scalar instead of running a kernel.
    let selects_exactly_one = mask.true_count() == 1;
    if selects_exactly_one && array.vtable().scalar_at_fn().is_some() {
        let position = mask.first().vortex_expect("true_count == 1");
        let scalar = scalar_at(array, position)?;
        return Ok(ConstantArray::new(scalar, 1).into_array());
    }

    // Fallback: round-trip through Arrow and use its filter kernel.
    log::debug!("No filter implementation found for {}", array.encoding(),);

    let arrow_input = array.clone().into_arrow_preferred()?;
    // `None`: the predicate is built without a null buffer.
    let arrow_predicate = BooleanArray::new(mask_values.boolean_buffer().clone(), None);
    let arrow_output = arrow_select::filter::filter(arrow_input.as_ref(), &arrow_predicate)?;

    // Nullability of the result is taken from the input dtype.
    Ok(Array::from_arrow(
        arrow_output,
        array.dtype().is_nullable(),
    ))
}
/// Conversion from a non-nullable boolean [`Array`] into a [`Mask`].
impl TryFrom<Array> for Mask {
    type Error = VortexError;

    /// Builds a [`Mask`] from `array`.
    ///
    /// When `statistics().get_as::<u64>(Stat::TrueCount)` yields a value of zero or of the
    /// array's length, the mask is built with `new_false` / `new_true` and the array is not
    /// converted. Otherwise the array is converted with `into_bool` and the mask is built
    /// from its boolean buffer.
    ///
    /// # Errors
    ///
    /// Returns an error if the dtype is not `DType::Bool(Nullability::NonNullable)`, or if
    /// `into_bool` fails.
    fn try_from(array: Array) -> Result<Self, Self::Error> {
        let required_dtype = DType::Bool(Nullability::NonNullable);
        if array.dtype() != &required_dtype {
            vortex_bail!(
                "mask must be non-nullable bool, has dtype {}",
                array.dtype(),
            );
        }

        // Shortcut on the true-count statistic; all-false is checked before all-true, which
        // decides the outcome for a zero-length array.
        let len = array.len();
        match array.statistics().get_as::<u64>(Stat::TrueCount) {
            Some(0) => return Ok(Self::new_false(len)),
            Some(count) if count == len as u64 => return Ok(Self::new_true(len)),
            Some(_) | None => {}
        }

        let bool_array = array.into_bool()?;
        Ok(Self::from_buffer(bool_array.boolean_buffer()))
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::array::{BoolArray, PrimitiveArray};
    use crate::compute::filter::filter;
    use crate::IntoArray;

    /// Filtering a nullable `i32` array with a mask that keeps exactly the `Some` positions
    /// yields those values in their original order.
    #[test]
    fn test_filter() {
        let source_values = [Some(0i32), None, Some(1i32), None, Some(2i32)];
        let keep_flags = [true, false, true, false, true];

        let items = PrimitiveArray::from_option_iter(source_values).into_array();
        let mask = Mask::try_from(BoolArray::from_iter(keep_flags).into_array()).unwrap();

        let kept = filter(&items, &mask).unwrap().into_primitive().unwrap();
        assert_eq!(kept.as_slice::<i32>(), &[0i32, 1i32, 2i32]);
    }
}