use std::cmp::min;
use std::ops::AddAssign;
use num_traits::AsPrimitive;
use vortex_array::ArrayRef;
use vortex_array::ArrayView;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::filter::FilterKernel;
use vortex_array::dtype::NativePType;
use vortex_array::match_each_unsigned_integer_ptype;
use vortex_array::validity::Validity;
use vortex_buffer::BitBuffer;
use vortex_buffer::buffer_mut;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_mask::Mask;
use crate::RunEnd;
use crate::array::RunEndArrayExt;
use crate::compute::take::take_indices_unchecked;
/// If the number of selected elements divided by the number of runs is below this ratio, the
/// filter is served by a `take` over the selected indices instead of rewriting the run ends.
const FILTER_TAKE_THRESHOLD: f64 = 0.1;

/// Selections with fewer than this many selected elements are always served by a `take` over the
/// selected indices, regardless of the number of runs.
const FILTER_TAKE_COUNT_THRESHOLD: usize = 25;

impl FilterKernel for RunEnd {
    /// Filters a run-end encoded array by `mask`.
    ///
    /// Sparse selections are handled by taking the selected indices. A selection is sparse when
    /// it is small relative to the number of runs (`FILTER_TAKE_THRESHOLD`) or small in absolute
    /// terms (`FILTER_TAKE_COUNT_THRESHOLD`).
    ///
    /// Otherwise the run ends are rewritten directly, so the result stays run-end encoded. Runs
    /// without any selected element are dropped together with their values. The surviving ends
    /// are rebased so they count only selected elements, and the result has offset 0.
    ///
    /// # Panics
    ///
    /// Panics if `mask` is not `Mask::Values`; this is a `FilterKernel` precondition.
    fn filter(
        array: ArrayView<'_, Self>,
        mask: &Mask,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        let mask_values = mask
            .values()
            .vortex_expect("FilterKernel precondition: mask is Mask::Values");
        let true_count = mask_values.true_count();

        let runs_ratio = true_count as f64 / array.ends().len() as f64;
        if runs_ratio < FILTER_TAKE_THRESHOLD || true_count < FILTER_TAKE_COUNT_THRESHOLD {
            return Ok(Some(take_indices_unchecked(
                array,
                mask_values.indices(),
                &Validity::NonNullable,
                ctx,
            )?));
        }

        let primitive_run_ends = array.ends().clone().execute::<PrimitiveArray>(ctx)?;
        let (run_ends, values_mask) =
            match_each_unsigned_integer_ptype!(primitive_run_ends.ptype(), |P| {
                filter_run_end_primitive(
                    primitive_run_ends.as_slice::<P>(),
                    array.offset() as u64,
                    array.len() as u64,
                    mask_values.bit_buffer(),
                )?
            });
        let values = array.values().filter(values_mask)?;

        // SAFETY: `filter_run_end_primitive` writes one end per run it marks in `values_mask`, so
        // `run_ends` and the filtered `values` have the same length. Each kept run contains at
        // least one selected element, so the kept ends (running counts of selected elements) are
        // strictly increasing. Because the source runs cover the whole logical range, the last
        // kept end equals `true_count`. That matches the new length with offset 0.
        let filtered =
            unsafe { RunEnd::new_unchecked(run_ends.into_array(), values, 0, true_count) };
        Ok(Some(filtered.into_array()))
    }
}
/// Filters the physical runs of a run-end encoded array by a logical selection mask.
///
/// # Parameters
///
/// * `run_ends` — the physical run ends of the source array, as stored (before applying
///   `offset`). `R` is the run-end integer type.
/// * `offset` — the logical offset of the source array. It is subtracted from every run end.
/// * `length` — the logical length of the source array. Run ends are clamped to it.
/// * `mask` — one selection bit per logical element. It must hold at least `length` bits.
///
/// # Returns
///
/// A pair of:
///
/// * the new run ends, one per kept run. Each is the running count of selected elements up to
///   and including that run.
/// * a mask with one bit per input run. A bit is set when the run contains at least one
///   selected element; these are the runs (and values) to keep.
///
/// Runs with no selected element are dropped, so the returned ends are strictly increasing.
fn filter_run_end_primitive<R: NativePType + AddAssign + From<bool> + AsPrimitive<u64>>(
    run_ends: &[R],
    offset: u64,
    length: u64,
    mask: &BitBuffer,
) -> VortexResult<(PrimitiveArray, Mask)> {
    debug_assert!(
        length <= mask.len() as u64,
        "selection mask must cover the logical length of the array"
    );

    let mut new_run_ends = buffer_mut![R::zero(); run_ends.len()];
    // Logical start (relative to `offset`) of the run currently being visited.
    let mut start = 0u64;
    // Number of runs kept so far; also the next write position in `new_run_ends`.
    let mut kept = 0usize;
    // Running count of selected elements seen so far.
    let mut count = R::zero();

    let new_mask: Mask = BitBuffer::collect_bool(run_ends.len(), |run| {
        // Clamp to the logical length: the last runs may extend past the end of a sliced array.
        let end = min(run_ends[run].as_() - offset, length);
        let lo: usize = start.try_into().vortex_expect("index must fit in usize");
        let hi: usize = end.try_into().vortex_expect("index must fit in usize");

        let mut keep = false;
        for idx in lo..hi {
            // SAFETY: `idx < hi <= length`, and the caller passes a mask holding at least
            // `length` bits (checked by the debug assertion above).
            let selected = unsafe { mask.value_unchecked(idx) };
            count += <R as From<bool>>::from(selected);
            keep |= selected;
        }

        // Always write the current count. A dropped run's slot is overwritten by the next kept
        // run and is truncated away at the end.
        new_run_ends[kept] = count;
        kept += usize::from(keep);
        start = end;
        keep
    })
    .into();
    new_run_ends.truncate(kept);

    Ok((
        PrimitiveArray::new(new_run_ends, Validity::NonNullable),
        new_mask,
    ))
}
#[cfg(test)]
mod tests {
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_error::VortexResult;
    use vortex_mask::Mask;

    use crate::RunEnd;
    use crate::RunEndArray;

    /// Run-end encodes `[1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]`, i.e. runs of values 1, 4, 2, 5.
    fn ree_array() -> RunEndArray {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let flat = PrimitiveArray::from_iter([1, 1, 1, 4, 4, 4, 2, 2, 5, 5, 5, 5]).into_array();
        RunEnd::encode(flat, &mut ctx).unwrap()
    }

    #[test]
    fn filter_sliced_run_end() -> VortexResult<()> {
        // Slicing 2..7 leaves [1, 4, 4, 4, 2]; keeping positions 0, 3 and 4 yields [1, 4, 2].
        let sliced = ree_array().slice(2..7)?;
        let selection = Mask::from_iter([true, false, false, true, true]);
        let filtered = sliced.filter(selection)?;

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        assert_arrays_eq!(
            filtered,
            RunEnd::new(
                PrimitiveArray::from_iter([1u8, 2, 3]).into_array(),
                PrimitiveArray::from_iter([1i32, 4, 2]).into_array(),
                &mut ctx,
            )
        );
        Ok(())
    }

    #[test]
    fn filter_sliced_run_end_preserves_encoding() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();

        // Four runs of 32 elements each: 10, 20, 30, 40.
        let run_values = [10i32, 20, 30, 40];
        let flat: Vec<i32> = (0..run_values.len() * 32)
            .map(|idx| run_values[idx / 32])
            .collect();
        let encoded = RunEnd::encode(PrimitiveArray::from_iter(flat).into_array(), &mut ctx)?;

        // Drop the first half of the first run, then keep every other element.
        let sliced = encoded.into_array().slice(16..128)?;
        let every_other = Mask::from_iter((0..sliced.len()).map(|idx| idx % 2 == 0));
        let filtered = sliced.filter(every_other)?;

        let executed = filtered.execute_until::<RunEnd>(&mut ctx)?;
        assert_eq!(
            executed.encoding_id().as_ref(),
            "vortex.runend",
            "Filter(Slice(RunEnd)) should preserve RunEnd encoding"
        );

        // The 16 surviving tens halve to 8; each full run of 32 halves to 16.
        let expected: Vec<i32> = [(10, 8), (20, 16), (30, 16), (40, 16)]
            .into_iter()
            .flat_map(|(value, count)| std::iter::repeat_n(value, count))
            .collect();
        assert_arrays_eq!(executed, PrimitiveArray::from_iter(expected));
        Ok(())
    }
}