use std::cmp::Ordering;
use std::ops::BitAnd;
use std::sync::{Arc, OnceLock};
use arrow_array::BooleanArray;
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder};
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexExpect, VortexResult};
use crate::array::ConstantArray;
use crate::arrow::FromArrowArray;
use crate::compute::scalar_at;
use crate::encoding::Encoding;
use crate::stats::{ArrayStatistics, Stat};
use crate::{ArrayDType, ArrayData, Canonical, IntoArrayData, IntoArrayVariant, IntoCanonical};
/// Selectivity cut-off used by `FilterMask::iter`: a mask whose selectivity is strictly
/// greater than this value is iterated as slices, otherwise as individual indices.
const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8;
/// Encoding-level hook for filtering an array of type `Array` with a [`FilterMask`].
pub trait FilterFn<Array> {
/// Filters `array` by `mask`, returning the resulting array or an error.
///
/// The top-level `filter` function debug-asserts that the result has
/// `mask.true_count()` elements and the same dtype as `array`.
fn filter(&self, array: &Array, mask: &FilterMask) -> VortexResult<ArrayData>;
}
/// Adapts an encoding's typed `FilterFn<E::Array>` so it can be invoked on an untyped
/// [`ArrayData`]: the array is downcast to the encoding's array type and then delegated.
impl<E: Encoding> FilterFn<ArrayData> for E
where
    E: FilterFn<E::Array>,
    for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
    fn filter(&self, array: &ArrayData, mask: &FilterMask) -> VortexResult<ArrayData> {
        // A failed downcast is propagated to the caller as an error.
        let (typed_array, typed_encoding) = array.try_downcast_ref::<E>()?;
        <E as FilterFn<E::Array>>::filter(typed_encoding, typed_array, mask)
    }
}
/// Keeps the elements of `array` at the positions selected by `mask`.
///
/// # Errors
///
/// Returns an error if `mask.len()` differs from `array.len()`, or if the underlying
/// filter implementation fails.
///
/// Two cases are answered without touching the data: an all-false mask yields an empty
/// canonical array of the same dtype, and an all-true mask yields a clone of `array`.
pub fn filter(array: &ArrayData, mask: &FilterMask) -> VortexResult<ArrayData> {
    let mask_len = mask.len();
    let array_len = array.len();
    if mask_len != array_len {
        vortex_bail!(
            "mask.len() is {}, does not equal array.len() of {}",
            mask_len,
            array_len
        );
    }

    let selected = mask.true_count();
    match selected {
        // Nothing selected: an empty array of the right dtype.
        0 => return Ok(Canonical::empty(array.dtype())?.into()),
        // Everything selected: the input is already the answer.
        n if n == mask_len => return Ok(array.clone()),
        _ => {}
    }

    let result = filter_impl(array, mask)?;

    // Sanity-check the encoding's implementation in debug builds.
    debug_assert_eq!(
        result.len(),
        selected,
        "Filter length mismatch {}",
        array.encoding().id()
    );
    debug_assert_eq!(
        result.dtype(),
        array.dtype(),
        "Filter dtype mismatch {}",
        array.encoding().id()
    );

    Ok(result)
}
/// Dispatches a filter to the best available strategy for `array`'s encoding.
///
/// Order of preference:
/// 1. the encoding's own `filter_fn`, if it provides one;
/// 2. for a mask selecting exactly one element, a length-1 constant array built from
///    `scalar_at` (only when the encoding provides `scalar_at_fn`);
/// 3. conversion to Arrow and `arrow_select::filter::filter`.
fn filter_impl(array: &ArrayData, mask: &FilterMask) -> VortexResult<ArrayData> {
    if let Some(encoding_filter) = array.encoding().filter_fn() {
        return encoding_filter.filter(array, mask);
    }

    let single_selected = mask.true_count() == 1;
    if single_selected && array.encoding().scalar_at_fn().is_some() {
        let position = mask.first().vortex_expect("true_count == 1");
        let scalar = scalar_at(array, position)?;
        return Ok(ConstantArray::new(scalar, 1).into_array());
    }

    // Fall back to Arrow.
    log::debug!(
        "No filter implementation found for {}",
        array.encoding().id(),
    );

    let arrow_values = array.clone().into_arrow()?;
    let arrow_mask = BooleanArray::new(mask.boolean_buffer().clone(), None);
    let arrow_filtered = arrow_select::filter::filter(arrow_values.as_ref(), &arrow_mask)?;

    let nullable = array.dtype().is_nullable();
    Ok(ArrayData::from_arrow(arrow_filtered, nullable))
}
/// A selection mask over `len` positions.
///
/// The state lives behind an `Arc`, so cloning a mask shares the underlying
/// representations rather than copying them.
#[derive(Clone, Debug)]
pub struct FilterMask(Arc<Inner>);
/// Shared state behind a [`FilterMask`]: up to three lazily-populated views of the same
/// mask, plus summary counts fixed at construction time.
#[derive(Debug)]
struct Inner {
/// Bitmap view: one bit per position.
buffer: OnceLock<BooleanBuffer>,
/// Positions that are selected.
indices: OnceLock<Vec<usize>>,
/// Half-open `(start, end)` runs of selected positions.
slices: OnceLock<Vec<(usize, usize)>>,
/// Total number of positions covered by the mask.
len: usize,
/// Number of selected positions.
true_count: usize,
/// Ratio of selected positions to `len`, as computed by the constructor.
selectivity: f64,
}
impl Inner {
    /// Returns the bitmap view of the mask, building and caching it on first use.
    ///
    /// All-false and all-true masks are built directly from `len`; otherwise the bitmap
    /// is derived from whichever of `indices` or `slices` is already populated.
    ///
    /// # Panics
    ///
    /// Panics if the mask is mixed and neither `indices` nor `slices` is populated.
    fn buffer(&self) -> &BooleanBuffer {
        self.buffer.get_or_init(|| {
            if self.true_count == 0 {
                return BooleanBuffer::new_unset(self.len);
            }
            if self.true_count == self.len {
                return BooleanBuffer::new_set(self.len);
            }

            if let Some(indices) = self.indices.get() {
                // Start from an all-false bitmap and flip each selected position.
                let mut buf = BooleanBufferBuilder::new(self.len);
                buf.append_n(self.len, false);
                indices.iter().for_each(|idx| buf.set_bit(*idx, true));
                return BooleanBuffer::from(buf);
            }

            if let Some(slices) = self.slices.get() {
                let mut buf = BooleanBufferBuilder::new(self.len);
                for (start, end) in slices.iter().copied() {
                    // Pad the gap before this run with `false`, then emit the run itself.
                    // This relies on the runs being sorted and non-overlapping.
                    buf.append_n(start - buf.len(), false);
                    buf.append_n(end - start, true);
                }
                // Pad whatever remains after the last run.
                if let Some((_, end)) = slices.last() {
                    buf.append_n(self.len - end, false);
                }
                debug_assert_eq!(buf.len(), self.len);
                return BooleanBuffer::from(buf);
            }

            vortex_panic!("No mask representation found")
        })
    }

    /// Returns the selected positions, building and caching them on first use.
    ///
    /// # Panics
    ///
    /// Panics if the mask is mixed and neither `buffer` nor `slices` is populated.
    fn indices(&self) -> &[usize] {
        self.indices.get_or_init(|| {
            if self.true_count == 0 {
                return vec![];
            }
            if self.true_count == self.len {
                return (0..self.len).collect();
            }

            if let Some(buffer) = self.buffer.get() {
                let mut indices = Vec::with_capacity(self.true_count);
                indices.extend(buffer.set_indices());
                return indices;
            }

            if let Some(slices) = self.slices.get() {
                let mut indices = Vec::with_capacity(self.true_count);
                indices.extend(slices.iter().flat_map(|(start, end)| *start..*end));
                return indices;
            }

            vortex_panic!("No mask representation found")
        })
    }

    /// Returns the half-open `(start, end)` runs of selected positions, building and
    /// caching them on first use.
    ///
    /// An all-false mask (including an empty one) has no runs.
    ///
    /// # Panics
    ///
    /// Panics if the mask is mixed and neither `buffer` nor `indices` is populated.
    fn slices(&self) -> &[(usize, usize)] {
        self.slices.get_or_init(|| {
            // An all-false mask has no other representation to derive runs from, so it
            // must be answered here; this also keeps an empty mask from producing the
            // degenerate run `(0, 0)`.
            if self.true_count == 0 {
                return vec![];
            }
            if self.true_count == self.len {
                return vec![(0, self.len)];
            }

            if let Some(buffer) = self.buffer.get() {
                return buffer.set_slices().collect();
            }

            if let Some(indices) = self.indices.get() {
                // `true_count` is an upper bound: every run holds at least one index.
                let mut slices = Vec::with_capacity(self.true_count);
                let mut iter = indices.iter().copied();

                let Some(first) = iter.next() else {
                    return slices;
                };

                // Coalesce consecutive indices into runs.
                let mut start = first;
                let mut prev = first;
                for curr in iter {
                    if curr != prev + 1 {
                        slices.push((start, prev + 1));
                        start = curr;
                    }
                    prev = curr;
                }
                slices.push((start, prev + 1));
                return slices;
            }

            vortex_panic!("No mask representation found")
        })
    }

    /// Returns the first selected position, or `None` if nothing is selected.
    ///
    /// Only consults representations that are already populated; nothing is built or
    /// cached by this call.
    fn first(&self) -> Option<usize> {
        if self.true_count == 0 {
            return None;
        }
        if self.true_count == self.len {
            return Some(0);
        }

        if let Some(buffer) = self.buffer.get() {
            return buffer.set_indices().next();
        }
        if let Some(indices) = self.indices.get() {
            return indices.first().copied();
        }
        if let Some(slices) = self.slices.get() {
            return slices.first().map(|(start, _)| *start);
        }

        None
    }
}
impl FilterMask {
    /// Ratio of selected positions to `len`.
    ///
    /// An empty mask reports `0.0` instead of the NaN that `0 / 0` would produce.
    fn selectivity_of(true_count: usize, len: usize) -> f64 {
        if len == 0 {
            0.0
        } else {
            true_count as f64 / len as f64
        }
    }

    /// Creates a mask of `length` positions in which every position is selected.
    pub fn new_true(length: usize) -> Self {
        Self(Arc::new(Inner {
            buffer: Default::default(),
            indices: Default::default(),
            slices: Default::default(),
            len: length,
            true_count: length,
            selectivity: 1.0,
        }))
    }

    /// Creates a mask of `length` positions in which no position is selected.
    pub fn new_false(length: usize) -> Self {
        Self(Arc::new(Inner {
            buffer: Default::default(),
            indices: Default::default(),
            slices: Default::default(),
            len: length,
            true_count: 0,
            selectivity: 0.0,
        }))
    }

    /// Creates a mask from a bitmap; set bits are the selected positions.
    pub fn from_buffer(buffer: BooleanBuffer) -> Self {
        let true_count = buffer.count_set_bits();
        let len = buffer.len();
        Self(Arc::new(Inner {
            buffer: OnceLock::from(buffer),
            indices: Default::default(),
            slices: Default::default(),
            len,
            true_count,
            selectivity: Self::selectivity_of(true_count, len),
        }))
    }

    /// Creates a mask of `len` positions from the list of selected positions.
    ///
    /// `vec.len()` is taken as the number of selected positions, and the conversion to
    /// slices coalesces consecutive values, so `vec` is expected to be sorted ascending
    /// without duplicates. That is not checked here.
    ///
    /// # Panics
    ///
    /// Panics if any index is `>= len`.
    pub fn from_indices(len: usize, vec: Vec<usize>) -> Self {
        let true_count = vec.len();
        assert!(vec.iter().all(|&idx| idx < len));
        Self(Arc::new(Inner {
            buffer: Default::default(),
            indices: OnceLock::from(vec),
            slices: Default::default(),
            len,
            true_count,
            selectivity: Self::selectivity_of(true_count, len),
        }))
    }

    /// Creates a mask of `len` positions from half-open `(start, end)` runs of selected
    /// positions.
    ///
    /// The conversion to a bitmap pads the gaps between successive runs, so the runs are
    /// expected to be sorted and non-overlapping. That is not checked here.
    ///
    /// # Panics
    ///
    /// Panics if any run is empty (`start >= end`) or extends past `len`.
    pub fn from_slices(len: usize, vec: Vec<(usize, usize)>) -> Self {
        assert!(vec.iter().all(|&(b, e)| b < e && e <= len));
        let true_count: usize = vec.iter().map(|(b, e)| e - b).sum();
        Self(Arc::new(Inner {
            buffer: Default::default(),
            indices: Default::default(),
            slices: OnceLock::from(vec),
            len,
            true_count,
            selectivity: Self::selectivity_of(true_count, len),
        }))
    }

    /// Creates a mask of `len` positions selecting the values yielded by both `lhs` and
    /// `rhs`.
    ///
    /// This is a two-pointer merge that advances whichever side is smaller, so both
    /// iterators are expected to yield ascending indices.
    ///
    /// # Panics
    ///
    /// Panics (via [`FilterMask::from_indices`]) if a common index is `>= len`.
    pub fn from_intersection_indices(
        len: usize,
        lhs: impl Iterator<Item = usize>,
        rhs: impl Iterator<Item = usize>,
    ) -> Self {
        let mut intersection = Vec::with_capacity(len);
        let mut lhs = lhs.peekable();
        let mut rhs = rhs.peekable();
        while let (Some(&l), Some(&r)) = (lhs.peek(), rhs.peek()) {
            match l.cmp(&r) {
                Ordering::Less => {
                    lhs.next();
                }
                Ordering::Greater => {
                    rhs.next();
                }
                Ordering::Equal => {
                    intersection.push(l);
                    lhs.next();
                    rhs.next();
                }
            }
        }
        Self::from_indices(len, intersection)
    }

    /// Total number of positions covered by the mask.
    #[inline]
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.0.len
    }

    /// Number of selected positions.
    #[inline]
    pub fn true_count(&self) -> usize {
        self.0.true_count
    }

    /// Number of positions that are not selected.
    #[inline]
    pub fn false_count(&self) -> usize {
        self.len() - self.true_count()
    }

    /// Ratio of selected positions to the mask length.
    #[inline]
    pub fn selectivity(&self) -> f64 {
        self.0.selectivity
    }

    /// Returns the mask as a bitmap, computing and caching it if necessary.
    pub fn boolean_buffer(&self) -> &BooleanBuffer {
        self.0.buffer()
    }

    /// Returns the selected positions, computing and caching them if necessary.
    pub fn indices(&self) -> &[usize] {
        self.0.indices()
    }

    /// Returns the half-open runs of selected positions, computing and caching them if
    /// necessary.
    pub fn slices(&self) -> &[(usize, usize)] {
        self.0.slices()
    }

    /// Returns the first selected position, or `None` if nothing is selected.
    pub fn first(&self) -> Option<usize> {
        self.0.first()
    }

    /// Returns the representation best suited for iterating this mask: slices when the
    /// selectivity exceeds `FILTER_SLICES_SELECTIVITY_THRESHOLD`, indices otherwise.
    pub fn iter(&self) -> FilterIter {
        if self.selectivity() > FILTER_SLICES_SELECTIVITY_THRESHOLD {
            FilterIter::Slices(self.slices())
        } else {
            FilterIter::Indices(self.indices())
        }
    }

    /// Returns the sub-mask covering positions `offset..offset + length`, re-based so
    /// that position `offset` of `self` becomes position `0` of the result.
    ///
    /// Callers are expected to pass a window that lies within the mask; the all-true,
    /// all-false, indices and slices paths do not validate it.
    pub fn slice(&self, offset: usize, length: usize) -> Self {
        if self.true_count() == 0 {
            return Self::new_false(length);
        }
        if self.true_count() == self.len() {
            return Self::new_true(length);
        }

        if let Some(buffer) = self.0.buffer.get() {
            return Self::from_buffer(buffer.slice(offset, length));
        }

        let end = offset + length;

        if let Some(indices) = self.0.indices.get() {
            let indices = indices
                .iter()
                .copied()
                .filter(|&idx| offset <= idx && idx < end)
                .map(|idx| idx - offset)
                .collect();
            return Self::from_indices(length, indices);
        }

        if let Some(slices) = self.0.slices.get() {
            let slices = slices
                .iter()
                .copied()
                // Keep only the runs that overlap the window.
                .filter(|(s, e)| *s < end && *e > offset)
                // Clamp each run to the window, then re-base it onto the new origin.
                // Without the subtraction the runs would keep their old coordinates.
                .map(|(s, e)| (s.max(offset) - offset, e.min(end) - offset))
                // A zero-length window can clamp a straddling run down to nothing.
                .filter(|(s, e)| s < e)
                .collect();
            return Self::from_slices(length, slices);
        }

        vortex_panic!("No mask representation found")
    }

    /// Narrows this mask by a second mask defined over its selected positions.
    ///
    /// `mask` has one position per selected position of `self`; position `i` of `mask`
    /// decides whether the `i`-th selected position of `self` stays selected. The result
    /// has the same length as `self`.
    ///
    /// # Panics
    ///
    /// Panics if `mask.len()` differs from `self.true_count()`.
    pub fn intersect_by_rank(&self, mask: &FilterMask) -> FilterMask {
        assert_eq!(self.true_count(), mask.len());

        if mask.true_count() == mask.len() {
            return self.clone();
        }
        if mask.true_count() == 0 {
            return Self::new_false(self.len());
        }

        let indices = self.0.indices();
        Self::from_indices(
            self.len(),
            mask.indices()
                .iter()
                .map(|idx| {
                    // SAFETY: every index of `mask` is `< mask.len()`. The constructors
                    // enforce this for indices and slices, and a bitmap cannot have a
                    // set bit past its length. `mask.len()` equals `self.true_count()`
                    // by the assertion above, which is the length of `indices`.
                    unsafe { *indices.get_unchecked(*idx) }
                })
                .collect(),
        )
    }
}
/// Borrowed view of a [`FilterMask`] in the representation chosen by `FilterMask::iter`.
pub enum FilterIter<'a> {
/// The selected positions, one entry per position.
Indices(&'a [usize]),
/// The selected positions as `(start, end)` runs.
Slices(&'a [(usize, usize)]),
}
/// Two masks are equal when they have the same length and select the same positions,
/// regardless of which representations each has materialized.
impl PartialEq for FilterMask {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&*self.0, &*other.0);

        // Cheap rejections first.
        if lhs.len != rhs.len || lhs.true_count != rhs.true_count {
            return false;
        }

        // With equal counts, all-false and all-true masks need no further comparison.
        if lhs.true_count == 0 || lhs.true_count == lhs.len {
            return true;
        }

        // Prefer a representation both sides already hold.
        if let Some((a, b)) = lhs.buffer.get().zip(rhs.buffer.get()) {
            return a == b;
        }
        if let Some((a, b)) = lhs.indices.get().zip(rhs.indices.get()) {
            return a == b;
        }
        if let Some((a, b)) = lhs.slices.get().zip(rhs.slices.get()) {
            return a == b;
        }

        // No common representation: materialize the bitmaps and compare those.
        lhs.buffer() == rhs.buffer()
    }
}
// Marker impl: `PartialEq` above compares only length, selected count and selected
// positions, never the floating-point selectivity.
impl Eq for FilterMask {}
/// Position-wise AND of two masks of equal length.
///
/// Panics if the two masks differ in length.
impl BitAnd for &FilterMask {
    type Output = FilterMask;

    fn bitand(self, rhs: Self) -> Self::Output {
        let len = self.len();
        if len != rhs.len() {
            vortex_panic!("FilterMasks must have the same length");
        }

        let lhs_true = self.true_count();
        let rhs_true = rhs.true_count();

        // Either side all-false: the result is all-false.
        if lhs_true == 0 || rhs_true == 0 {
            return FilterMask::new_false(len);
        }
        // Either side all-true: the result is the other side.
        if lhs_true == len {
            return rhs.clone();
        }
        if rhs_true == len {
            return self.clone();
        }

        // Use a representation both sides already hold, if there is one.
        if let Some((left, right)) = self.0.buffer.get().zip(rhs.0.buffer.get()) {
            return FilterMask::from_buffer(left & right);
        }
        if let Some((left, right)) = self.0.indices.get().zip(rhs.0.indices.get()) {
            return FilterMask::from_intersection_indices(
                len,
                left.iter().copied(),
                right.iter().copied(),
            );
        }

        // Otherwise materialize both bitmaps and AND them.
        FilterMask::from_buffer(self.boolean_buffer() & rhs.boolean_buffer())
    }
}
/// Builds a mask from a non-nullable boolean array.
///
/// When the array's `TrueCount` statistic shows the mask is all-false or all-true, the
/// mask is created from the length alone, without converting the array to a bool array.
impl TryFrom<ArrayData> for FilterMask {
    type Error = VortexError;

    fn try_from(array: ArrayData) -> Result<Self, Self::Error> {
        if array.dtype() != &DType::Bool(Nullability::NonNullable) {
            vortex_bail!(
                "mask must be non-nullable bool, has dtype {}",
                array.dtype(),
            );
        }

        let len = array.len();
        let known_true_count = array.statistics().get_as_cast::<u64>(Stat::TrueCount);
        match known_true_count {
            Some(0) => return Ok(Self::new_false(len)),
            Some(count) if count == len as u64 => return Ok(Self::new_true(len)),
            // Mixed mask, or no statistic available: fall through to the bitmap.
            _ => {}
        }

        let bits = array.into_bool()?.boolean_buffer();
        Ok(Self::from_buffer(bits))
    }
}
impl From<BooleanBuffer> for FilterMask {
fn from(value: BooleanBuffer) -> Self {
Self::from_buffer(value)
}
}
impl FromIterator<bool> for FilterMask {
fn from_iter<T: IntoIterator<Item = bool>>(iter: T) -> Self {
Self::from_buffer(BooleanBuffer::from_iter(iter))
}
}
#[cfg(test)]
mod test {
    use itertools::Itertools;

    use super::*;
    use crate::array::{BoolArray, PrimitiveArray};
    use crate::compute::filter::filter;
    use crate::{IntoArrayData, IntoCanonical};

    /// Builds a bitmap-backed mask from literal booleans.
    fn buffer_mask<const N: usize>(bits: [bool; N]) -> FilterMask {
        FilterMask::from_buffer(BooleanBuffer::from_iter(bits))
    }

    /// Filtering a nullable primitive array keeps only the selected (here non-null) values.
    #[test]
    fn test_filter() {
        let values =
            PrimitiveArray::from_option_iter([Some(0i32), None, Some(1i32), None, Some(2i32)])
                .into_array();
        let selection = BoolArray::from_iter([true, false, true, false, true]).into_array();
        let mask = FilterMask::try_from(selection).unwrap();

        let kept = filter(&values, &mask)
            .unwrap()
            .into_canonical()
            .unwrap()
            .into_primitive()
            .unwrap();
        assert_eq!(kept.as_slice::<i32>(), &[0i32, 1i32, 2i32]);
    }

    /// An all-true mask exposes consistent counts and representations.
    #[test]
    fn filter_mask_all_true() {
        let all = FilterMask::new_true(5);
        assert_eq!(all.len(), 5);
        assert_eq!(all.true_count(), 5);
        assert_eq!(all.selectivity(), 1.0);
        assert_eq!(all.indices(), &[0, 1, 2, 3, 4]);
        assert_eq!(all.slices(), &[(0, 5)]);
        assert_eq!(all.boolean_buffer(), &BooleanBuffer::new_set(5));
    }

    /// An all-false mask exposes consistent counts and representations.
    #[test]
    fn filter_mask_all_false() {
        let none = FilterMask::new_false(5);
        assert_eq!(none.len(), 5);
        assert_eq!(none.true_count(), 0);
        assert_eq!(none.selectivity(), 0.0);
        assert_eq!(none.indices(), &[] as &[usize]);
        assert_eq!(none.slices(), &[]);
        assert_eq!(none.boolean_buffer(), &BooleanBuffer::new_unset(5));
    }

    /// The same mask built three different ways yields the same views.
    #[test]
    fn filter_mask_from() {
        let from_indices = FilterMask::from_indices(5, vec![0, 2, 3]);
        let from_slices = FilterMask::from_slices(5, vec![(0, 1), (2, 4)]);
        let from_buffer = buffer_mask([true, false, true, true, false]);

        for mask in [&from_indices, &from_slices, &from_buffer] {
            assert_eq!(mask.len(), 5);
            assert_eq!(mask.true_count(), 3);
            assert_eq!(mask.selectivity(), 0.6);
            assert_eq!(mask.indices(), &[0, 2, 3]);
            assert_eq!(mask.slices(), &[(0, 1), (2, 4)]);
            assert_eq!(
                &mask.boolean_buffer().iter().collect_vec(),
                &[true, false, true, true, false]
            );
        }
    }

    /// Equality ignores which representation a mask was built from.
    #[test]
    fn filter_mask_eq() {
        assert_eq!(
            FilterMask::new_true(5),
            FilterMask::from_buffer(BooleanBuffer::new_set(5))
        );
        assert_eq!(
            FilterMask::new_false(5),
            FilterMask::from_buffer(BooleanBuffer::new_unset(5))
        );

        let by_indices = || FilterMask::from_indices(5, vec![0, 2, 3]);
        assert_eq!(
            by_indices(),
            FilterMask::from_slices(5, vec![(0, 1), (2, 4)])
        );
        assert_eq!(
            by_indices(),
            buffer_mask([true, false, true, true, false])
        );
    }

    /// With an all-true base mask, rank intersection selects exactly the rank mask's positions.
    #[test]
    fn filter_mask_intersect_all_as_bit_and() {
        let base = buffer_mask([true, true, true, true, true]);
        let rank = buffer_mask([false, true, false, true, true]);
        let expected = FilterMask::from_indices(5, vec![1, 3, 4]);
        assert_eq!(base.intersect_by_rank(&rank), expected);
    }

    /// An all-true rank mask leaves the base mask unchanged.
    #[test]
    fn filter_mask_intersect_all_true() {
        let base = buffer_mask([false, false, true, true, true]);
        let rank = buffer_mask([true, true, true]);
        let expected = FilterMask::from_indices(5, vec![2, 3, 4]);
        assert_eq!(base.intersect_by_rank(&rank), expected);
    }

    /// A mixed rank mask drops the corresponding selected positions of the base mask.
    #[test]
    fn filter_mask_intersect_true() {
        let base = buffer_mask([true, false, false, true, true]);
        let rank = buffer_mask([true, false, true]);
        let expected = FilterMask::from_indices(5, vec![0, 4]);
        assert_eq!(base.intersect_by_rank(&rank), expected);
    }

    /// An all-false rank mask clears every selected position.
    #[test]
    fn filter_mask_intersect_false() {
        let base = buffer_mask([true, false, false, true, true]);
        let rank = buffer_mask([false, false, false]);
        let expected = FilterMask::from_indices(5, vec![]);
        assert_eq!(base.intersect_by_rank(&rank), expected);
    }
}