use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::Bool;
use crate::arrays::BoolArray;
use crate::arrays::bool::BoolArrayExt;
use crate::dtype::DType;
use crate::scalar_fn::fns::cast::CastKernel;
use crate::scalar_fn::fns::cast::CastReduce;
impl CastReduce for Bool {
fn cast(array: ArrayView<'_, Bool>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
if !dtype.is_boolean() {
return Ok(None);
}
let Some(new_validity) = array
.validity()?
.trivial_cast_nullability(dtype.nullability(), array.len())?
else {
return Ok(None);
};
Ok(Some(
BoolArray::new(array.to_bit_buffer(), new_validity).into_array(),
))
}
}
impl CastKernel for Bool {
fn cast(
array: ArrayView<'_, Bool>,
dtype: &DType,
ctx: &mut ExecutionCtx,
) -> VortexResult<Option<ArrayRef>> {
if !dtype.is_boolean() {
return Ok(None);
}
let new_validity =
array
.validity()?
.cast_nullability(dtype.nullability(), array.len(), ctx)?;
Ok(Some(
BoolArray::new(array.to_bit_buffer(), new_validity).into_array(),
))
}
}
#[cfg(test)]
mod tests {
use std::sync::LazyLock;
use rstest::rstest;
use vortex_session::VortexSession;
use crate::Canonical;
use crate::IntoArray;
use crate::VortexSessionExecute;
use crate::arrays::BoolArray;
use crate::builtins::ArrayBuiltins;
use crate::compute::conformance::cast::test_cast_conformance;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::session::ArraySession;
static SESSION: LazyLock<VortexSession> =
LazyLock::new(|| VortexSession::empty().with::<ArraySession>());
#[test]
fn try_cast_bool_success() {
let bool = BoolArray::from_iter(vec![Some(true), Some(false), Some(true)]);
let res = bool
.into_array()
.cast(DType::Bool(Nullability::NonNullable));
assert!(res.is_ok());
assert_eq!(res.unwrap().dtype(), &DType::Bool(Nullability::NonNullable));
}
#[test]
fn try_cast_bool_fail() {
let bool = BoolArray::from_iter(vec![Some(true), Some(false), None]);
let mut ctx = SESSION.create_execution_ctx();
let result = bool
.into_array()
.cast(DType::Bool(Nullability::NonNullable))
.and_then(|a| a.execute::<Canonical>(&mut ctx).map(|c| c.into_array()));
assert!(result.is_err(), "Expected error, got: {result:?}");
}
#[rstest]
#[case(BoolArray::from_iter(vec![true, false, true, true, false]))]
#[case(BoolArray::from_iter(vec![Some(true), Some(false), None, Some(true), None]))]
#[case(BoolArray::from_iter(vec![true]))]
#[case(BoolArray::from_iter(vec![false, false]))]
fn test_cast_bool_conformance(#[case] array: BoolArray) {
test_cast_conformance(&array.into_array());
}
}