use std::iter::TrustedLen;
use std::sync::{Arc, OnceLock};
use arrow_array::BooleanArray;
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, MutableBuffer};
use num_traits::AsPrimitive;
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect, VortexResult};
use crate::array::{BoolArray, ConstantArray};
use crate::arrow::FromArrowArray;
use crate::compute::scalar_at;
use crate::encoding::Encoding;
use crate::stats::ArrayStatistics;
use crate::{ArrayDType, ArrayData, Canonical, IntoArrayData, IntoCanonical};
/// `range_selectivity` above which a [`FilterMask`] prefers runs of set positions
/// (slices, from `set_slices`) over individual set positions (indices) when it is
/// iterated or when its caches are warmed on clone.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
/// Encoding-specific filter kernel for arrays of type `Array`.
///
/// An implementation returns a new array holding the elements of `array` at the
/// positions set in `mask`. When invoked through [`filter`]:
/// - `mask` has already been checked to have the same length as the array.
/// - Masks with no set entries or all set entries are handled before the kernel is
///   reached.
/// - The kernel's result must keep the input dtype and have exactly
///   `mask.true_count()` elements; otherwise [`filter`] returns an error.
pub trait FilterFn<Array> {
    /// Returns the elements of `array` selected by `mask`.
    fn filter(&self, array: &Array, mask: FilterMask) -> VortexResult<ArrayData>;
}
impl<E: Encoding> FilterFn<ArrayData> for E
where
E: FilterFn<E::Array>,
for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
fn filter(&self, array: &ArrayData, mask: FilterMask) -> VortexResult<ArrayData> {
let array_ref = <&E::Array>::try_from(array)?;
let encoding = array
.encoding()
.as_any()
.downcast_ref::<E>()
.ok_or_else(|| vortex_err!("Mismatched encoding"))?;
FilterFn::filter(encoding, array_ref, mask)
}
}
/// Keeps only the elements of `array` whose positions are set in `mask`.
///
/// Trivial masks are short-circuited:
/// - A mask with no set entries yields an empty canonical array of `array`'s dtype.
/// - A mask with every entry set returns a clone of `array`.
///
/// Otherwise, if the encoding provides a [`FilterFn`] kernel, that kernel is used and
/// its output is checked: it must keep the input dtype and have exactly
/// `mask.true_count()` elements. Encodings without a kernel take one of two paths:
/// - When exactly one element survives and the encoding supports [`scalar_at`], that
///   single value is returned as a one-element `ConstantArray`.
/// - In every other case, the array is converted to Arrow and filtered with Arrow's
///   `filter` kernel.
///
/// # Errors
///
/// Returns an error in these cases:
/// - `mask` and `array` differ in length.
/// - The kernel changes the dtype or returns the wrong number of elements.
/// - Any underlying conversion or kernel call fails.
pub fn filter(array: &ArrayData, mask: FilterMask) -> VortexResult<ArrayData> {
    if mask.len() != array.len() {
        vortex_bail!(
            "mask.len() is {}, does not equal array.len() of {}",
            mask.len(),
            array.len()
        );
    }

    let true_count = mask.true_count();
    if true_count == 0 {
        return Ok(Canonical::empty(array.dtype())?.into());
    }
    if true_count == mask.len() {
        return Ok(array.clone());
    }

    match array.encoding().filter_fn() {
        Some(filter_fn) => {
            let filtered = filter_fn.filter(array, mask)?;
            // Enforce the kernel contract: same dtype, one output element per set bit.
            if filtered.dtype() != array.dtype() {
                vortex_bail!(
                    "FilterFn {} changed array dtype from {} to {}",
                    array.encoding().id(),
                    array.dtype(),
                    filtered.dtype()
                );
            }
            if filtered.len() != true_count {
                vortex_bail!(
                    "FilterFn {} returned incorrect length: expected {}, got {}",
                    array.encoding().id(),
                    true_count,
                    filtered.len()
                );
            }
            Ok(filtered)
        }
        None => {
            // Exactly one survivor: fetch it directly and wrap it in a one-element
            // ConstantArray instead of converting the whole array to Arrow.
            if true_count == 1 && array.encoding().scalar_at_fn().is_some() {
                let idx = mask.indices()?[0];
                return Ok(ConstantArray::new(scalar_at(array, idx)?, 1).into_array());
            }

            log::debug!(
                "No filter implementation found for {}",
                array.encoding().id(),
            );
            let arrow_values = array.clone().into_arrow()?;
            let arrow_mask = BooleanArray::new(mask.to_boolean_buffer()?, None);
            let filtered = arrow_select::filter::filter(arrow_values.as_ref(), &arrow_mask)?;
            Ok(ArrayData::from_arrow(filtered, array.dtype().is_nullable()))
        }
    }
}
/// A selection mask for [`filter`], backed by a non-nullable boolean array.
///
/// The true count is computed when the mask is constructed. Three views are
/// materialized lazily on first use:
/// - the boolean buffer,
/// - the list of set indices,
/// - the list of set slices.
///
/// Each view is cached behind an `Arc<OnceLock<_>>`, so clones of a mask share the
/// caches.
#[derive(Debug)]
pub struct FilterMask {
    /// The mask values; `TryFrom<ArrayData>` only accepts dtype `Bool(NonNullable)`.
    array: ArrayData,
    /// Number of set entries, computed once at construction.
    true_count: usize,
    /// Set at construction to `true_count / len`, which is NaN for an empty mask.
    /// It is compared against `FILTER_SLICES_SELECTIVITY_THRESHOLD` to choose
    /// between slice and index iteration.
    range_selectivity: f64,
    /// Lazily computed positions of the set entries.
    indices: Arc<OnceLock<Vec<usize>>>,
    /// Lazily computed runs of set entries, as produced by `BooleanBuffer::set_slices`.
    slices: Arc<OnceLock<Vec<(usize, usize)>>>,
    /// Lazily materialized boolean buffer, obtained by canonicalizing `array`.
    buffer: Arc<OnceLock<BooleanBuffer>>,
}
impl Clone for FilterMask {
fn clone(&self) -> Self {
if self.range_selectivity > FILTER_SLICES_SELECTIVITY_THRESHOLD {
let _: VortexResult<_> = self
.slices
.get_or_try_init(|| Ok(self.boolean_buffer()?.set_slices().collect()));
} else {
let _: VortexResult<_> = self.indices();
}
Self {
array: self.array.clone(),
true_count: self.true_count,
range_selectivity: self.range_selectivity,
indices: self.indices.clone(),
slices: self.slices.clone(),
buffer: self.buffer.clone(),
}
}
}
/// Iterator over the positions of set bits.
///
/// Wraps arrow's `BitIndexIterator` and adds a caller-supplied total length, so that
/// it can report an exact `size_hint`.
pub struct BitIndexIterator<'a> {
    /// Underlying arrow iterator that yields set-bit positions.
    inner: arrow_buffer::bit_iterator::BitIndexIterator<'a>,
    /// Counter advanced by `next`; `size_hint` reports `trusted_len - index`.
    index: usize,
    /// Total number of items the creator promises `inner` will yield. The
    /// `TrustedLen` impl relies on this being exact.
    trusted_len: usize,
}
impl<'a> BitIndexIterator<'a> {
    /// Wraps `inner`, which must yield exactly `trusted_len` indices.
    ///
    /// The `TrustedLen` and `ExactSizeIterator` impls report lengths derived from
    /// `trusted_len`. Passing a value that does not match what `inner` yields
    /// therefore breaks their contract.
    pub fn new(
        inner: arrow_buffer::bit_iterator::BitIndexIterator<'a>,
        trusted_len: usize,
    ) -> Self {
        // Nothing has been yielded yet.
        let index = 0;
        Self {
            trusted_len,
            inner,
            index,
        }
    }
}
impl<'a> Iterator for BitIndexIterator<'a> {
    type Item = usize;

    /// Yields the next set-bit position.
    ///
    /// `index` advances only when an item is actually produced. Calling `next` again
    /// after exhaustion therefore leaves `index` at the number of yielded items,
    /// instead of pushing it past `trusted_len` and making `size_hint` underflow.
    fn next(&mut self) -> Option<Self::Item> {
        let position = self.inner.next()?;
        self.index += 1;
        Some(position)
    }

    /// Exact number of positions still to be yielded, derived from `trusted_len`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.trusted_len - self.index;
        (remaining, Some(remaining))
    }
}
// SAFETY: `TrustedLen` requires `size_hint` to report the exact number of remaining
// items. `BitIndexIterator::size_hint` reports `trusted_len - index`, so this holds
// only if `trusted_len` equals the number of set bits that `inner` yields. Within
// this file, `FilterMask::iter` passes the mask's `true_count`.
// NOTE(review): `BitIndexIterator::new` is a safe `pub fn`, so nothing stops a caller
// from passing a wrong `trusted_len`. Consider making it `unsafe` or `pub(crate)`.
unsafe impl<'a> TrustedLen for BitIndexIterator<'a> {}
// `len()` uses the same `trusted_len`-based `size_hint`.
impl<'a> ExactSizeIterator for BitIndexIterator<'a> {}
/// Borrowed view over the set positions of a [`FilterMask`], returned by
/// [`FilterMask::iter`].
pub enum FilterIter<'a> {
    /// Cached list of set positions.
    Indices(&'a [usize]),
    /// Lazy iterator over set positions, used while the index cache is not yet populated.
    IndicesIter(BitIndexIterator<'a>),
    /// Cached list of runs of set positions (from `set_slices`).
    Slices(&'a [(usize, usize)]),
    /// Lazy iterator over runs of set positions, used while the slice cache is not yet
    /// populated.
    SlicesIter(arrow_buffer::bit_iterator::BitSliceIterator<'a>),
}
impl FilterMask {
    /// Builds a mask of `length` entries in which exactly the positions in `indices`
    /// are set.
    ///
    /// Duplicate indices are harmless and order does not matter. Each index is
    /// converted with `AsPrimitive<usize>` (an `as` cast).
    ///
    /// # Panics
    ///
    /// Panics if any converted index is `>= length`. Without this check, indices that
    /// fall into the padding bits of the final byte would be silently lost, and larger
    /// ones would hit an opaque slice-index panic.
    pub fn from_indices<V: AsPrimitive<usize>, I: IntoIterator<Item = V>>(
        length: usize,
        indices: I,
    ) -> Self {
        // Zero-filled buffer sized for `length` bits, so every position starts unset.
        let mut buffer = MutableBuffer::new_null(length);
        for idx in indices {
            let idx: usize = idx.as_();
            assert!(
                idx < length,
                "FilterMask::from_indices: index {idx} out of bounds for length {length}"
            );
            arrow_buffer::bit_util::set_bit(&mut buffer, idx);
        }
        Self::from(BooleanBufferBuilder::new_from_buffer(buffer, length).finish())
    }

    /// Number of entries in the mask (set or unset).
    pub fn len(&self) -> usize {
        self.array.len()
    }

    /// Whether the mask has no entries at all.
    pub fn is_empty(&self) -> bool {
        self.array.is_empty()
    }

    /// Number of set entries.
    pub fn true_count(&self) -> usize {
        self.true_count
    }

    /// Number of unset entries.
    pub fn false_count(&self) -> usize {
        self.array.len() - self.true_count
    }

    /// Fraction of entries that are set, `true_count / len`; NaN for an empty mask.
    pub fn selectivity(&self) -> f64 {
        self.true_count as f64 / self.len() as f64
    }

    /// Selectivity used to choose between slice and index iteration. It is set at
    /// construction to `true_count / len`.
    pub fn range_selectivity(&self) -> f64 {
        self.range_selectivity
    }

    /// Returns an owned copy of the mask's boolean buffer, materializing and caching
    /// it on first use.
    pub fn to_boolean_buffer(&self) -> VortexResult<BooleanBuffer> {
        log::debug!(
            "FilterMask: len {} selectivity: {} true_count: {}",
            self.len(),
            self.range_selectivity(),
            self.true_count,
        );
        self.boolean_buffer().cloned()
    }

    /// Borrows the cached boolean buffer. On first use it is built by canonicalizing
    /// the mask array into a bool array.
    fn boolean_buffer(&self) -> VortexResult<&BooleanBuffer> {
        self.buffer.get_or_try_init(|| {
            Ok(self
                .array
                .clone()
                .into_canonical()?
                .into_bool()?
                .boolean_buffer())
        })
    }

    /// Borrows the cached list of set positions, computing it on first use.
    fn indices(&self) -> VortexResult<&[usize]> {
        self.indices
            .get_or_try_init(|| {
                let mut indices = Vec::with_capacity(self.true_count());
                indices.extend(self.boolean_buffer()?.set_indices());
                Ok(indices)
            })
            .map(|v| v.as_slice())
    }

    /// Returns a view over the set positions.
    ///
    /// Runs of set positions (slices) are used when `range_selectivity` exceeds
    /// `FILTER_SLICES_SELECTIVITY_THRESHOLD`; individual positions are used otherwise.
    /// An already-populated cache is borrowed. If the cache is empty, the boolean
    /// buffer is iterated lazily without populating that cache.
    pub fn iter(&self) -> VortexResult<FilterIter> {
        let iter = if self.range_selectivity > FILTER_SLICES_SELECTIVITY_THRESHOLD {
            match self.slices.get() {
                Some(slices) => FilterIter::Slices(slices.as_slice()),
                None => FilterIter::SlicesIter(self.boolean_buffer()?.set_slices()),
            }
        } else if let Some(indices) = self.indices.get() {
            FilterIter::Indices(indices.as_slice())
        } else {
            FilterIter::IndicesIter(BitIndexIterator::new(
                self.boolean_buffer()?.set_indices(),
                self.true_count,
            ))
        };
        Ok(iter)
    }

    /// Iterates runs of set positions directly from the boolean buffer.
    #[deprecated(note = "Move to using iter() instead")]
    pub fn iter_slices(&self) -> VortexResult<impl Iterator<Item = (usize, usize)> + '_> {
        Ok(self.boolean_buffer()?.set_slices())
    }

    /// Iterates individual set positions directly from the boolean buffer.
    #[deprecated(note = "Move to using iter() instead")]
    pub fn iter_indices(&self) -> VortexResult<impl Iterator<Item = usize> + '_> {
        Ok(self.boolean_buffer()?.set_indices())
    }
}
impl TryFrom<ArrayData> for FilterMask {
type Error = VortexError;
fn try_from(array: ArrayData) -> Result<Self, Self::Error> {
if array.dtype() != &DType::Bool(Nullability::NonNullable) {
vortex_bail!(
"mask must be non-nullable bool, has dtype {}",
array.dtype(),
);
}
let true_count = array
.statistics()
.compute_true_count()
.ok_or_else(|| vortex_err!("Failed to compute true count for boolean array"))?;
let selectivity = true_count as f64 / array.len() as f64;
Ok(Self {
array,
true_count,
range_selectivity: selectivity,
indices: Arc::new(OnceLock::new()),
slices: Arc::new(OnceLock::new()),
buffer: Arc::new(OnceLock::new()),
})
}
}
impl From<BooleanBuffer> for FilterMask {
    /// Wraps `value` in a `BoolArray` and builds a mask from it.
    ///
    /// Panics via `vortex_expect` if `TryFrom<ArrayData>` rejects the resulting array.
    /// NOTE(review): this presumably cannot happen, because `BoolArray::from` should
    /// produce a non-nullable bool array — confirm.
    fn from(value: BooleanBuffer) -> Self {
        let bool_array = BoolArray::from(value).into_array();
        Self::try_from(bool_array).vortex_expect("Failed to convert BooleanBuffer to FilterMask")
    }
}
impl FromIterator<bool> for FilterMask {
fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
Self::from(BooleanBuffer::from_iter(iter))
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::array::{BoolArray, PrimitiveArray};
    use crate::compute::filter::filter;
    use crate::{IntoArrayData, IntoCanonical};

    #[test]
    fn test_filter() {
        // Nullable input: the mask keeps only the non-null positions 0, 2 and 4.
        let values = [Some(0i32), None, Some(1i32), None, Some(2i32)];
        let items = PrimitiveArray::from_nullable_vec(values.to_vec()).into_array();

        let keep = [true, false, true, false, true];
        let mask = FilterMask::try_from(BoolArray::from_iter(keep).into_array()).unwrap();

        let primitive = filter(&items, mask)
            .unwrap()
            .into_canonical()
            .unwrap()
            .into_primitive()
            .unwrap();
        let expected = vec![0i32, 1i32, 2i32];
        assert_eq!(primitive.into_maybe_null_slice::<i32>(), expected);
    }
}