use num_traits::AsPrimitive;
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::dict::TakeExecute;
use vortex_array::dtype::DType;
use vortex_array::match_each_integer_ptype;
use vortex_array::scalar_fn::fns::cast::CastReduce;
use vortex_array::scalar_fn::fns::mask::MaskReduce;
use vortex_array::validity::Validity;
use vortex_array::vtable::ValidityHelper;
use vortex_error::VortexResult;
use vortex_error::vortex_err;
use super::ByteBool;
use super::ByteBoolArray;
impl CastReduce for ByteBool {
    /// Reduces a cast of a [`ByteBoolArray`] when only nullability differs.
    ///
    /// If `dtype` matches the array's dtype ignoring nullability, the byte buffer is reused
    /// as-is and only the validity is converted to `dtype`'s nullability. Any other target
    /// dtype is not handled here, and `Ok(None)` is returned so another kernel can handle it.
    ///
    /// # Errors
    ///
    /// Propagates any error from converting the validity to the requested nullability.
    fn cast(array: &ByteBoolArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
        // Guard: only nullability-changing casts are handled by this kernel.
        if !array.dtype().eq_ignore_nullability(dtype) {
            return Ok(None);
        }

        let converted_validity = array
            .validity()
            .clone()
            .cast_nullability(dtype.nullability(), array.len())?;
        let bytes = array.buffer().clone();

        Ok(Some(ByteBoolArray::new(bytes, converted_validity).into_array()))
    }
}
impl MaskReduce for ByteBool {
    /// Applies `mask` to a [`ByteBoolArray`] by narrowing its validity.
    ///
    /// The byte buffer is reused unchanged. The result's validity is the array's existing
    /// validity combined via [`Validity::and`] with `mask` wrapped as a
    /// [`Validity::Array`]. This kernel always produces a result (`Some`).
    ///
    /// # Errors
    ///
    /// Propagates any error from combining the two validities.
    fn mask(array: &ByteBoolArray, mask: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
        let mask_as_validity = Validity::Array(mask.clone());
        let combined_validity = array.validity().clone().and(mask_as_validity)?;
        let masked = ByteBoolArray::new(array.buffer().clone(), combined_validity);
        Ok(Some(masked.into_array()))
    }
}
impl TakeExecute for ByteBool {
    /// Gathers the values of `array` at the positions listed in `indices`.
    ///
    /// `indices` is first executed into a [`PrimitiveArray`] of any integer ptype. The output
    /// holds one `bool` per index. Its validity is `array`'s validity taken at the same
    /// indices.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - executing `indices` fails;
    /// - any index falls outside `0..array.len()`;
    /// - taking the validity fails.
    fn take(
        array: &ByteBoolArray,
        indices: &ArrayRef,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        let indices = indices.clone().execute::<PrimitiveArray>(ctx)?;
        let bools = array.as_slice();

        // Bounds-check every index so a bad index is reported as an error instead of
        // panicking on slice indexing. `as_()` has `as`-cast semantics, so a negative signed
        // index wraps to a very large `usize` and is rejected here too.
        // NOTE(review): physical values at null index slots are checked as well; confirm
        // callers always keep those slots in range.
        let taken_bools = match_each_integer_ptype!(indices.ptype(), |I| {
            indices
                .as_slice::<I>()
                .iter()
                .map(|&idx| {
                    let idx: usize = idx.as_();
                    bools.get(idx).copied().ok_or_else(|| {
                        vortex_err!(
                            "take index {} out of bounds for ByteBoolArray of length {}",
                            idx,
                            bools.len()
                        )
                    })
                })
                .collect::<VortexResult<Vec<bool>>>()
        })?;

        // `indices` is no longer borrowed, so it can be moved instead of cloned.
        let validity = array.validity().take(&indices.into_array())?;

        Ok(Some(
            ByteBoolArray::from_vec(taken_bools, validity).into_array(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::compute::conformance::consistency::test_array_consistency;
    use vortex_array::compute::conformance::filter::test_filter_conformance;
    use vortex_array::compute::conformance::mask::test_mask_conformance;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar_fn::fns::operators::Operator;
    use super::*;

    /// Evaluates `lhs == rhs` element-wise through the generic binary-operator entry point.
    fn compare_eq(lhs: ByteBoolArray, rhs: ByteBoolArray) -> ArrayRef {
        lhs.into_array()
            .binary(rhs.into_array(), Operator::Eq)
            .unwrap()
    }

    #[test]
    fn test_slice() {
        let array = ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]);
        let expected = ByteBoolArray::from(vec![Some(true), None, Some(false)]).into_array();
        let sliced = array.slice(1..4).unwrap();
        assert_arrays_eq!(sliced, expected);
    }

    #[test]
    fn test_compare_all_equal() {
        let result = compare_eq(
            ByteBoolArray::from(vec![true; 5]),
            ByteBoolArray::from(vec![true; 5]),
        );
        assert_arrays_eq!(result, ByteBoolArray::from(vec![true; 5]).into_array());
    }

    #[test]
    fn test_compare_all_different() {
        let result = compare_eq(
            ByteBoolArray::from(vec![false; 5]),
            ByteBoolArray::from(vec![true; 5]),
        );
        assert_arrays_eq!(result, ByteBoolArray::from(vec![false; 5]).into_array());
    }

    #[test]
    fn test_compare_with_nulls() {
        let result = compare_eq(
            ByteBoolArray::from(vec![true; 5]),
            ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]),
        );
        let expected =
            ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);
        assert_arrays_eq!(result, expected.into_array());
    }

    #[test]
    fn test_mask_byte_bool() {
        let inputs = [
            ByteBoolArray::from(vec![true, false, true, true, false]),
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]),
        ];
        for input in inputs {
            test_mask_conformance(&input.into_array());
        }
    }

    #[test]
    fn test_filter_byte_bool() {
        let inputs = [
            ByteBoolArray::from(vec![true, false, true, true, false]),
            ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]),
        ];
        for input in inputs {
            test_filter_conformance(&input.into_array());
        }
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(true), None, Some(false), None]))]
    #[case(ByteBoolArray::from(vec![true, false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    fn test_take_byte_bool_conformance(#[case] input: ByteBoolArray) {
        test_take_conformance(&input.into_array());
    }

    #[test]
    fn test_cast_bytebool_to_nullable() {
        let nullable_bool = DType::Bool(Nullability::Nullable);
        let casted = ByteBoolArray::from(vec![true, false, true, false])
            .into_array()
            .cast(nullable_bool.clone())
            .unwrap();
        assert_eq!(casted.dtype(), &nullable_bool);
        assert_eq!(casted.len(), 4);
    }

    #[rstest]
    #[case(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case(ByteBoolArray::from(vec![false]))]
    #[case(ByteBoolArray::from(vec![true]))]
    #[case(ByteBoolArray::from(vec![Some(true), None]))]
    fn test_cast_bytebool_conformance(#[case] input: ByteBoolArray) {
        test_cast_conformance(&input.into_array());
    }

    #[rstest]
    #[case::non_nullable(ByteBoolArray::from(vec![true, false, true, true, false]))]
    #[case::nullable(ByteBoolArray::from(vec![Some(true), Some(false), None, Some(true), None]))]
    #[case::all_true(ByteBoolArray::from(vec![true, true, true, true]))]
    #[case::all_false(ByteBoolArray::from(vec![false, false, false, false]))]
    #[case::single_true(ByteBoolArray::from(vec![true]))]
    #[case::single_false(ByteBoolArray::from(vec![false]))]
    #[case::single_null(ByteBoolArray::from(vec![None]))]
    #[case::mixed_with_nulls(ByteBoolArray::from(vec![Some(true), None, Some(false), None, Some(true)]))]
    fn test_bytebool_consistency(#[case] input: ByteBoolArray) {
        test_array_consistency(&input.into_array());
    }
}