use itertools::Itertools as _;
use num_traits::AsPrimitive;
use vortex_buffer::BitBuffer;
use vortex_buffer::get_bit;
use vortex_error::VortexResult;
use vortex_mask::Mask;
use crate::ArrayRef;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::Bool;
use crate::arrays::BoolArray;
use crate::arrays::ConstantArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::bool::BoolArrayExt;
use crate::arrays::dict::TakeExecute;
use crate::builtins::ArrayBuiltins;
use crate::executor::ExecutionCtx;
use crate::match_each_integer_ptype;
use crate::scalar::Scalar;
impl TakeExecute for Bool {
fn take(
array: ArrayView<'_, Bool>,
indices: &ArrayRef,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
let indices_nulls_zeroed = match indices.validity_mask()? {
Mask::AllTrue(_) => indices.clone(),
Mask::AllFalse(_) => {
return Ok(Some(
ConstantArray::new(Scalar::null(array.dtype().as_nullable()), indices.len())
.into_array(),
));
}
Mask::Values(_) => indices
.clone()
.fill_null(Scalar::from(0).cast(indices.dtype())?)?,
};
let indices_nulls_zeroed = indices_nulls_zeroed.execute::<PrimitiveArray>(ctx)?;
let buffer = match_each_integer_ptype!(indices_nulls_zeroed.ptype(), |I| {
take_valid_indices(&array.to_bit_buffer(), indices_nulls_zeroed.as_slice::<I>())
});
Ok(Some(
BoolArray::new(buffer, array.validity()?.take(indices)?).into_array(),
))
}
}
fn take_valid_indices<I: AsPrimitive<usize>>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
if bools.len() <= 4096 {
let bools = bools.iter().collect_vec();
take_byte_bool(bools, indices)
} else {
take_bool_impl(bools, indices)
}
}
fn take_byte_bool<I: AsPrimitive<usize>>(bools: Vec<bool>, indices: &[I]) -> BitBuffer {
BitBuffer::collect_bool(indices.len(), |idx| {
bools[unsafe { indices.get_unchecked(idx).as_() }]
})
}
fn take_bool_impl<I: AsPrimitive<usize>>(bools: &BitBuffer, indices: &[I]) -> BitBuffer {
let buffer = bools.inner().as_ref();
BitBuffer::collect_bool(indices.len(), |idx| {
let idx = unsafe { indices.get_unchecked(idx).as_() };
get_bit(buffer, bools.offset() + idx)
})
}
#[cfg(test)]
mod test {
use rstest::rstest;
use vortex_buffer::buffer;
use crate::IntoArray as _;
use crate::ToCanonical;
use crate::arrays::BoolArray;
use crate::arrays::PrimitiveArray;
use crate::arrays::bool::BoolArrayExt;
use crate::assert_arrays_eq;
use crate::compute::conformance::take::test_take_conformance;
use crate::validity::Validity;
#[test]
fn take_nullable() {
let reference = BoolArray::from_iter(vec![
Some(false),
Some(true),
Some(false),
None,
Some(false),
]);
let b = reference
.take(buffer![0, 3, 4].into_array())
.unwrap()
.to_bool();
assert_eq!(
b.to_bit_buffer(),
BoolArray::from_iter([Some(false), None, Some(false)]).to_bit_buffer()
);
let all_invalid_indices = PrimitiveArray::from_option_iter([None::<i32>, None, None]);
let b = reference.take(all_invalid_indices.into_array()).unwrap();
assert_arrays_eq!(b, BoolArray::from_iter([None, None, None]));
}
#[test]
fn test_bool_array_take_with_null_out_of_bounds_indices() {
let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
let indices = PrimitiveArray::new(
buffer![0, 3, 100],
Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
);
let actual = values.take(indices.into_array()).unwrap();
assert_arrays_eq!(actual, BoolArray::from_iter([Some(false), None, None]));
}
#[test]
fn test_non_null_bool_array_take_with_null_out_of_bounds_indices() {
let values = BoolArray::from_iter(vec![false, true, false, true, false]);
let indices = PrimitiveArray::new(
buffer![0, 3, 100],
Validity::Array(BoolArray::from_iter([true, true, false]).into_array()),
);
let actual = values.take(indices.into_array()).unwrap();
assert_arrays_eq!(
actual,
BoolArray::from_iter([Some(false), Some(true), None])
);
}
#[test]
fn test_bool_array_take_all_null_indices() {
let values = BoolArray::from_iter(vec![Some(false), Some(true), None, None, Some(false)]);
let indices = PrimitiveArray::new(
buffer![0, 3, 100],
Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
);
let actual = values.take(indices.into_array()).unwrap();
assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]));
}
#[test]
fn test_non_null_bool_array_take_all_null_indices() {
let values = BoolArray::from_iter(vec![false, true, false, true, false]);
let indices = PrimitiveArray::new(
buffer![0, 3, 100],
Validity::Array(BoolArray::from_iter([false, false, false]).into_array()),
);
let actual = values.take(indices.into_array()).unwrap();
assert_arrays_eq!(actual, BoolArray::from_iter([None, None, None]));
}
#[rstest]
#[case(BoolArray::from_iter([true, false, true, true, false]))]
#[case(BoolArray::from_iter([Some(true), None, Some(false), Some(true), None]))]
#[case(BoolArray::from_iter([true, false]))]
#[case(BoolArray::from_iter([true]))]
fn test_take_bool_conformance(#[case] array: BoolArray) {
test_take_conformance(&array.into_array());
}
}