use vortex_buffer::{Buffer, BufferMut};
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
use vortex_error::{VortexResult, vortex_err};
use vortex_mask::{AllOr, Mask};
use crate::arrays::PrimitiveVTable;
use crate::arrays::primitive::PrimitiveArray;
use crate::compute::{CastKernel, CastKernelAdapter};
use crate::vtable::ValidityHelper;
use crate::{ArrayRef, IntoArray, register_kernel};
impl CastKernel for PrimitiveVTable {
    /// Casts a primitive array to another primitive dtype.
    ///
    /// Returns `Ok(None)` when `dtype` is not a `DType::Primitive`, leaving the
    /// cast to be handled elsewhere.
    ///
    /// # Errors
    ///
    /// Propagates the error from `cast_nullability` when the array's validity
    /// cannot be converted to the requested nullability, and the error from the
    /// element-wise `cast` helper when a value cannot be converted to the
    /// target primitive type.
    fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
        let (target_ptype, target_nullability) = match dtype {
            DType::Primitive(ptype, nullability) => (*ptype, *nullability),
            _ => return Ok(None),
        };

        // The validity conversion runs before any value is looked at, so a
        // nullability error is reported ahead of any element conversion error.
        let target_validity = array
            .validity()
            .clone()
            .cast_nullability(target_nullability, array.len())?;

        // Same primitive type: keep the existing bytes and only swap in the
        // converted validity.
        let source_ptype = array.ptype();
        if source_ptype == target_ptype {
            let same_values = PrimitiveArray::from_byte_buffer(
                array.byte_buffer().clone(),
                source_ptype,
                target_validity,
            );
            return Ok(Some(same_values.into_array()));
        }

        // Different primitive type: convert the values one by one, using the
        // validity mask so that invalid positions are not converted.
        let mask = array.validity_mask();
        let converted = match_each_native_ptype!(target_ptype, |T| {
            match_each_native_ptype!(source_ptype, |F| {
                let values = cast::<F, T>(array.as_slice(), mask)?;
                PrimitiveArray::new(values, target_validity).into_array()
            })
        });
        Ok(Some(converted))
    }
}
// Hands the `CastKernel` implementation for `PrimitiveVTable`, wrapped in
// `CastKernelAdapter`, to the `register_kernel!` macro.
// NOTE(review): presumably this makes the kernel discoverable by the generic
// `cast` compute function; `register_kernel!` and `.lift()` are defined
// outside this file, so confirm there.
register_kernel!(CastKernelAdapter(PrimitiveVTable).lift());
/// Converts the elements of `array` from primitive type `F` to primitive type `T`.
///
/// `mask` decides which positions are converted:
/// - all positions valid: every element goes through the checked conversion;
/// - no positions valid: a zeroed buffer of the same length is returned and
///   no element is converted;
/// - mixed: valid positions are converted, invalid positions are filled with
///   `T::default()` without looking at the source value.
///
/// # Errors
///
/// Returns a `ComputeError` ("Failed to cast {value} to {ptype}") for the
/// first valid element for which `T::from` yields `None`.
fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
    let len = array.len();

    // Checked conversion of a single element; a `None` from `T::from` is
    // turned into the compute error reported to the caller.
    let convert = |source: &F| -> VortexResult<T> {
        T::from(*source).ok_or_else(
            || vortex_err!(ComputeError: "Failed to cast {} to {:?}", source, T::PTYPE),
        )
    };

    match mask.boolean_buffer() {
        AllOr::None => Ok(Buffer::zeroed(len)),
        AllOr::All => {
            let mut out = BufferMut::with_capacity(len);
            for source in array {
                let converted = convert(source)?;
                // SAFETY: `with_capacity(len)` reserved room for `len` elements
                // and this loop pushes exactly one element per input item.
                unsafe { out.push_unchecked(converted) }
            }
            Ok(out.freeze())
        }
        AllOr::Some(valid_bits) => {
            let mut out = BufferMut::with_capacity(len);
            for (source, is_valid) in array.iter().zip(valid_bits.iter()) {
                let element = if is_valid {
                    convert(source)?
                } else {
                    T::default()
                };
                // SAFETY: `with_capacity(len)` reserved room for `len` elements
                // and the zipped loop pushes at most one element per input item.
                unsafe { out.push_unchecked(element) }
            }
            Ok(out.freeze())
        }
    }
}
#[cfg(test)]
mod test {
    use arrow_buffer::BooleanBuffer;
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Walks one array through a chain of ptype and nullability casts and
    /// checks values and validity after every step.
    #[test]
    fn cast_u32_u8() {
        let nullable_u8 = DType::Primitive(PType::U8, Nullability::Nullable);
        let non_nullable_u8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        let nullable_u32 = DType::Primitive(PType::U32, Nullability::Nullable);

        let source = buffer![0u32, 10, 200].into_array();

        // u32 -> u8 via the `PType` conversion: result is non-nullable.
        let narrowed = cast(&source, PType::U8.into()).unwrap().to_primitive();
        assert_eq!(narrowed.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(narrowed.validity(), &Validity::NonNullable);

        // Same ptype, non-nullable -> nullable.
        let made_nullable = cast(narrowed.as_ref(), &nullable_u8)
            .unwrap()
            .to_primitive();
        assert_eq!(made_nullable.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(made_nullable.validity(), &Validity::AllValid);

        // Same ptype, nullable -> non-nullable.
        let made_non_nullable = cast(made_nullable.as_ref(), &non_nullable_u8)
            .unwrap()
            .to_primitive();
        assert_eq!(made_non_nullable.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(made_non_nullable.validity(), &Validity::NonNullable);

        // Ptype and nullability change together: u8 -> nullable u32.
        let widened = cast(made_non_nullable.as_ref(), &nullable_u32)
            .unwrap()
            .to_primitive();
        assert_eq!(widened.as_slice::<u32>(), vec![0u32, 10, 200]);
        assert_eq!(widened.validity(), &Validity::AllValid);

        // And back again: nullable u32 -> non-nullable u8.
        let narrowed_again = cast(widened.as_ref(), &non_nullable_u8)
            .unwrap()
            .to_primitive();
        assert_eq!(narrowed_again.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(narrowed_again.validity(), &Validity::NonNullable);
    }

    /// Integer to float cast keeps the values.
    #[test]
    fn cast_u32_f32() {
        let source = buffer![0u32, 10, 200].into_array();
        let as_f32 = cast(&source, PType::F32.into()).unwrap().to_primitive();
        assert_eq!(as_f32.as_slice::<f32>(), vec![0.0f32, 10., 200.]);
    }

    /// Casting `-1i32` to `u32` is expected to fail with a compute error
    /// naming the offending value and the target ptype.
    #[test]
    fn cast_i32_u32() {
        let source = buffer![-1i32].into_array();
        let error = cast(&source, PType::U32.into()).err().unwrap();
        match error {
            VortexError::ComputeError(message, _) => {
                assert_eq!(message.to_string(), "Failed to cast -1 to U32");
            }
            _ => unreachable!(),
        }
    }

    /// An array that contains nulls cannot be cast to a non-nullable dtype.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let with_nulls = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let error = cast(with_nulls.as_ref(), PType::I32.into()).unwrap_err();
        match error {
            VortexError::InvalidArgument(message, _) => {
                assert_eq!(
                    message.to_string(),
                    "Cannot cast array with invalid values to non-nullable type."
                );
            }
            _ => unreachable!(),
        }
    }

    /// The `-1` sits at an invalid position, so the cast to `u32` succeeds and
    /// that slot comes out as `0` while the validity mask is preserved.
    #[test]
    fn cast_with_invalid_nulls() {
        let values = buffer![-1i32, 0, 10];
        let validity = Validity::from_iter([false, true, true]);
        let source = PrimitiveArray::new(values, validity);

        let nullable_u32 = DType::Primitive(PType::U32, Nullability::Nullable);
        let result = cast(source.as_ref(), &nullable_u32)
            .unwrap()
            .to_primitive();

        assert_eq!(result.as_slice::<u32>(), vec![0, 0, 10]);
        let expected_mask = Mask::from(BooleanBuffer::from(vec![false, true, true]));
        assert_eq!(result.validity_mask(), expected_mask);
    }

    /// Runs the shared cast conformance checks over arrays of every primitive
    /// family, including nullable inputs and a single-element array.
    #[rstest]
    // Unsigned integers.
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    // Signed integers.
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    // Floats.
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    // Arrays containing nulls.
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    // Single element.
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}