use vortex_buffer::Buffer;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_panic;
use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::DecimalArray;
use crate::arrays::DecimalVTable;
use crate::dtype::DType;
use crate::dtype::DecimalType;
use crate::dtype::NativeDecimalType;
use crate::match_each_decimal_value_type;
use crate::scalar_fn::fns::cast::CastKernel;
use crate::vtable::ValidityHelper;
impl CastKernel for DecimalVTable {
    /// Casts a [`DecimalArray`] to another decimal [`DType`].
    ///
    /// Only casts that keep the scale and do not shrink the precision are supported; the
    /// nullability may change in either direction. If the smallest physical values type able to
    /// hold the target precision is wider than the type the array currently stores, the values
    /// are widened first. A cast to the array's own dtype returns a clone of the array.
    ///
    /// Returns `Ok(None)` when `dtype` is not a decimal type, so that the cast can be handled
    /// elsewhere.
    ///
    /// # Errors
    ///
    /// - the source and target scales differ;
    /// - the target precision is smaller than the source precision;
    /// - the validity cannot be converted to the target nullability
    ///   (`Validity::cast_nullability` fails, e.g. presumably when nulls are present and the
    ///   target is non-nullable).
    fn cast(
        array: &DecimalArray,
        dtype: &DType,
        _ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        let DType::Decimal(to_decimal_dtype, to_nullability) = dtype else {
            return Ok(None);
        };
        let DType::Decimal(from_decimal_dtype, _) = array.dtype() else {
            vortex_panic!(
                "DecimalArray must have decimal dtype, got {:?}",
                array.dtype()
            );
        };

        if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
            vortex_bail!(
                "Casting decimal with scale {} to scale {} not yet implemented",
                from_decimal_dtype.scale(),
                to_decimal_dtype.scale()
            );
        }
        if to_decimal_dtype.precision() < from_decimal_dtype.precision() {
            vortex_bail!(
                "Downcasting decimal from precision {} to {} not yet implemented",
                from_decimal_dtype.precision(),
                to_decimal_dtype.precision()
            );
        }
        if array.dtype() == dtype {
            return Ok(Some(array.clone().into_array()));
        }

        // Convert the validity before touching the values so that an impossible nullability
        // cast fails before any buffer is copied.
        let new_validity = array
            .validity()
            .clone()
            .cast_nullability(*to_nullability, array.len())?;

        // Make sure the physical values type is at least wide enough for the target precision;
        // a values type that is already wider is kept as-is.
        let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
        let array = if target_values_type > array.values_type() {
            upcast_decimal_values(array, target_values_type)?
        } else {
            array.clone()
        };

        // SAFETY: the buffer handle comes from `array` itself (either the original or the one
        // rebuilt by `upcast_decimal_values`), so it matches `array.values_type()`. The scale is
        // unchanged and the precision can only grow (both checked above), so every stored value
        // remains representable under `to_decimal_dtype`. After the widening step the values
        // type is no narrower than the smallest type for `to_decimal_dtype`. `new_validity` was
        // built for `array.len()` elements, which widening does not change.
        // NOTE(review): the authoritative precondition list lives on `new_unchecked_handle`;
        // confirm it is fully covered by the points above.
        let cast = unsafe {
            DecimalArray::new_unchecked_handle(
                array.buffer_handle().clone(),
                array.values_type(),
                *to_decimal_dtype,
                new_validity,
            )
        };
        Ok(Some(cast.into_array()))
    }
}
/// Re-encodes the values of `array` using the physical storage type `to_values_type`.
///
/// The decimal dtype (precision and scale) and the validity of `array` are carried over
/// unchanged; only the native type backing each value is converted. When `to_values_type`
/// already equals the array's values type, a clone of `array` is returned.
///
/// # Errors
///
/// Returns an error if `to_values_type` orders before the array's current values type, i.e. if
/// the request would be a downcast.
pub fn upcast_decimal_values(
    array: &DecimalArray,
    to_values_type: DecimalType,
) -> VortexResult<DecimalArray> {
    let source_values_type = array.values_type();

    // Same physical type: nothing to convert.
    if source_values_type == to_values_type {
        return Ok(array.clone());
    }
    if to_values_type < source_values_type {
        vortex_bail!(
            "Cannot downcast decimal values from {:?} to {:?}. Only upcasting is supported.",
            source_values_type,
            to_values_type
        );
    }

    // Dispatch on both the source and destination native types, then convert element-wise.
    match_each_decimal_value_type!(source_values_type, |Src| {
        match_each_decimal_value_type!(to_values_type, |Dst| {
            let widened = upcast_decimal_buffer::<Src, Dst>(array.buffer::<Src>());
            Ok(DecimalArray::new(
                widened,
                array.decimal_dtype(),
                array.validity().clone(),
            ))
        })
    })
}
/// Converts every element of `from` into the native decimal type `T`.
///
/// # Panics
///
/// Panics if a value cannot be represented in `T`. The caller in this module only requests
/// conversions to a values type that does not order before the source type, so a failure here
/// would indicate a bug.
fn upcast_decimal_buffer<F: NativeDecimalType, T: NativeDecimalType>(from: Buffer<F>) -> Buffer<T> {
    from.iter()
        .copied()
        .map(|value| T::from(value).vortex_expect("upcast should never fail"))
        .collect::<Buffer<T>>()
}
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;

    use super::upcast_decimal_values;
    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::DecimalDType;
    use crate::dtype::DecimalType;
    use crate::dtype::Nullability;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    #[test]
    fn cast_decimal_to_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );
        let nullable_dtype = DType::Decimal(decimal_dtype, Nullability::Nullable);
        let casted = array
            .into_array()
            .cast(nullable_dtype.clone())
            .unwrap()
            .to_decimal();
        assert_eq!(casted.dtype(), &nullable_dtype);
        assert_eq!(casted.validity(), &Validity::AllValid);
        assert_eq!(casted.len(), 3);
    }

    #[test]
    fn cast_nullable_to_non_nullable() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(buffer![100i32, 200, 300], decimal_dtype, Validity::AllValid);
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let casted = array
            .into_array()
            .cast(non_nullable_dtype.clone())
            .unwrap()
            .to_decimal();
        assert_eq!(casted.dtype(), &non_nullable_dtype);
        assert_eq!(casted.validity(), &Validity::NonNullable);
    }

    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);
        let non_nullable_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        array
            .into_array()
            .cast(non_nullable_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap();
    }

    #[test]
    fn cast_different_scale_fails() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
        let result = array
            .into_array()
            .cast(different_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Casting decimal with scale 2 to scale 3 not yet implemented")
        );
    }

    #[test]
    fn cast_downcast_precision_fails() {
        let array = DecimalArray::new(
            buffer![100i64],
            DecimalDType::new(18, 2),
            Validity::NonNullable,
        );
        let smaller_dtype = DType::Decimal(DecimalDType::new(10, 2), Nullability::NonNullable);
        let result = array
            .into_array()
            .cast(smaller_dtype)
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Downcasting decimal from precision 18 to 10 not yet implemented")
        );
    }

    #[test]
    fn cast_upcast_precision_succeeds() {
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::NonNullable);
        let casted = array.into_array().cast(wider_dtype).unwrap().to_decimal();
        assert_eq!(casted.precision(), 38);
        assert_eq!(casted.scale(), 2);
        assert_eq!(casted.len(), 3);
        assert_eq!(casted.values_type(), DecimalType::I128);
        // Widening the physical type must not alter the stored values.
        assert_eq!(casted.buffer::<i128>().as_ref(), &[100i128, 200, 300]);
    }

    #[test]
    fn cast_to_same_dtype_keeps_values_type() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );
        let same_dtype = DType::Decimal(decimal_dtype, Nullability::NonNullable);
        let casted = array
            .into_array()
            .cast(same_dtype.clone())
            .unwrap()
            .to_decimal();
        assert_eq!(casted.dtype(), &same_dtype);
        // An identity cast takes the early-return path and must not widen the storage.
        assert_eq!(casted.values_type(), DecimalType::I32);
        assert_eq!(casted.buffer::<i32>().as_ref(), &[100i32, 200, 300]);
    }

    #[test]
    fn cast_upcast_precision_to_nullable_preserves_nulls() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);
        let wider_dtype = DType::Decimal(DecimalDType::new(38, 2), Nullability::Nullable);
        let casted = array
            .into_array()
            .cast(wider_dtype.clone())
            .unwrap()
            .to_decimal();
        assert_eq!(casted.dtype(), &wider_dtype);
        assert_eq!(casted.values_type(), DecimalType::I128);
        assert_eq!(casted.len(), 3);
        let mask = casted.validity_mask().unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));
        let values = casted.buffer::<i128>();
        assert_eq!(values[0], 100);
        assert_eq!(values[2], 300);
    }

    #[test]
    fn cast_to_non_decimal_returns_err() {
        let array = DecimalArray::new(
            buffer![100i32],
            DecimalDType::new(10, 2),
            Validity::NonNullable,
        );
        let result = array
            .into_array()
            .cast(DType::Utf8(Nullability::NonNullable))
            .and_then(|a| a.to_canonical().map(|c| c.into_array()));
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("No CastKernel to cast canonical array")
        );
    }

    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] array: DecimalArray) {
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn upcast_decimal_values_i32_to_i64() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );
        assert_eq!(array.values_type(), DecimalType::I32);
        let casted = upcast_decimal_values(&array, DecimalType::I64).unwrap();
        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
        assert_eq!(casted.len(), 3);
        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer.as_ref(), &[100i64, 200, 300]);
    }

    #[test]
    fn upcast_decimal_values_i64_to_i128() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );
        let casted = upcast_decimal_values(&array, DecimalType::I128).unwrap();
        assert_eq!(casted.values_type(), DecimalType::I128);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
        let buffer = casted.buffer::<i128>();
        assert_eq!(buffer.as_ref(), &[10000i128, 20000, 30000]);
    }

    #[test]
    fn upcast_decimal_values_same_type_returns_clone() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::new(
            buffer![100i32, 200, 300],
            decimal_dtype,
            Validity::NonNullable,
        );
        let casted = upcast_decimal_values(&array, DecimalType::I32).unwrap();
        assert_eq!(casted.values_type(), DecimalType::I32);
        assert_eq!(casted.decimal_dtype(), decimal_dtype);
    }

    #[test]
    fn upcast_decimal_values_with_nulls() {
        let decimal_dtype = DecimalDType::new(10, 2);
        let array = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], decimal_dtype);
        let casted = upcast_decimal_values(&array, DecimalType::I64).unwrap();
        assert_eq!(casted.values_type(), DecimalType::I64);
        assert_eq!(casted.len(), 3);
        let mask = casted.validity_mask().unwrap();
        assert!(mask.value(0));
        assert!(!mask.value(1));
        assert!(mask.value(2));
        let buffer = casted.buffer::<i64>();
        assert_eq!(buffer[0], 100);
        assert_eq!(buffer[2], 300);
    }

    #[test]
    fn upcast_decimal_values_downcast_fails() {
        let decimal_dtype = DecimalDType::new(18, 4);
        let array = DecimalArray::new(
            buffer![10000i64, 20000, 30000],
            decimal_dtype,
            Validity::NonNullable,
        );
        let result = upcast_decimal_values(&array, DecimalType::I32);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Cannot downcast decimal values")
        );
    }
}