use vortex_dtype::DType;
use vortex_error::{VortexResult, vortex_bail, vortex_panic};
use crate::arrays::{DecimalArray, DecimalVTable};
use crate::compute::{CastKernel, CastKernelAdapter};
use crate::stats::ArrayStats;
use crate::vtable::ValidityHelper;
use crate::{ArrayRef, register_kernel};
/// Cast support for decimal arrays.
///
/// Decimal arrays can only change their nullability through this kernel: the precision and
/// scale of the target must match the source exactly.
impl CastKernel for DecimalVTable {
    /// Casts `array` to the decimal type `dtype`.
    ///
    /// Returns `Ok(None)` when `dtype` is not a decimal type, declining the cast so it can be
    /// handled elsewhere. When `dtype` equals the array's own dtype, the input is returned
    /// as-is. Otherwise a new `DecimalArray` is built from the source's `byte_buffer()` and
    /// `values_type()`, with its validity converted to the target nullability.
    ///
    /// # Errors
    ///
    /// - If the target precision/scale differs from the source's.
    /// - If the validity cannot be converted to the target nullability, e.g. when the array
    ///   contains nulls and the target is non-nullable.
    ///
    /// # Panics
    ///
    /// If `array` does not carry a `DType::Decimal` dtype, which would break the
    /// `DecimalArray` invariant.
    fn cast(&self, array: &DecimalArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
        // Only decimal targets are handled here; anything else is declined.
        let (target_ps, target_nullability) = match dtype {
            DType::Decimal(ps, nullability) => (ps, nullability),
            _ => return Ok(None),
        };

        let source_ps = match array.dtype() {
            DType::Decimal(ps, _) => ps,
            other => vortex_panic!("DecimalArray must have decimal dtype, got {:?}", other),
        };

        if source_ps != target_ps {
            vortex_bail!(
                "Cannot cast decimal({},{}) to decimal({},{})",
                source_ps.precision(),
                source_ps.scale(),
                target_ps.precision(),
                target_ps.scale()
            );
        }

        // Same precision/scale *and* same nullability: nothing to do.
        if dtype == array.dtype() {
            return Ok(Some(array.to_array()));
        }

        // Only nullability differs: carry the values over unchanged and convert the validity.
        let validity = array
            .validity()
            .clone()
            .cast_nullability(*target_nullability, array.len())?;

        let casted = DecimalArray {
            dtype: DType::Decimal(*source_ps, *target_nullability),
            values: array.byte_buffer(),
            values_type: array.values_type(),
            validity,
            stats_set: ArrayStats::default(),
        };
        Ok(Some(casted.to_array()))
    }
}
// Register `DecimalVTable`'s `CastKernel` impl, wrapped in `CastKernelAdapter`. This is
// presumably what lets the generic `cast` compute function dispatch decimal arrays to
// `DecimalVTable::cast` (the tests in this file call `cast` and rely on that behavior).
register_kernel!(CastKernelAdapter(DecimalVTable).lift());
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, DecimalDType, Nullability};

    use crate::arrays::DecimalArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// The `decimal(10, 2)` type shared by most tests in this module.
    fn dec_10_2() -> DecimalDType {
        DecimalDType::new(10, 2)
    }

    /// Widening non-nullable -> nullable keeps the values and reports all-valid.
    #[test]
    fn cast_decimal_to_nullable() {
        let source = DecimalArray::new(
            buffer![100i32, 200, 300],
            dec_10_2(),
            Validity::NonNullable,
        );
        let target = DType::Decimal(dec_10_2(), Nullability::Nullable);

        let result = cast(source.as_ref(), &target).unwrap().to_decimal();

        assert_eq!(result.len(), 3);
        assert_eq!(result.dtype(), &target);
        assert_eq!(result.validity(), &Validity::AllValid);
    }

    /// Narrowing nullable -> non-nullable succeeds when no value is null.
    #[test]
    fn cast_nullable_to_non_nullable() {
        let source = DecimalArray::new(buffer![100i32, 200, 300], dec_10_2(), Validity::AllValid);
        let target = DType::Decimal(dec_10_2(), Nullability::NonNullable);

        let result = cast(source.as_ref(), &target).unwrap().to_decimal();

        assert_eq!(result.dtype(), &target);
        assert_eq!(result.validity(), &Validity::NonNullable);
    }

    /// Narrowing to non-nullable must fail when a null is present.
    #[test]
    #[should_panic(expected = "Cannot cast array with invalid values to non-nullable type")]
    fn cast_nullable_with_nulls_to_non_nullable_fails() {
        let source = DecimalArray::from_option_iter([Some(100i32), None, Some(300)], dec_10_2());
        let target = DType::Decimal(dec_10_2(), Nullability::NonNullable);
        cast(source.as_ref(), &target).unwrap();
    }

    /// Changing precision/scale is rejected with a descriptive error.
    #[test]
    fn cast_different_precision_fails() {
        let source = DecimalArray::new(buffer![100i32], dec_10_2(), Validity::NonNullable);
        let target = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);

        let message = cast(source.as_ref(), &target)
            .expect_err("casting to a different precision/scale must fail")
            .to_string();
        assert!(message.contains("Cannot cast decimal(10,2) to decimal(15,3)"));
    }

    /// A non-decimal target is declined by the decimal kernel, so `cast` reports no kernel.
    #[test]
    fn cast_to_non_decimal_returns_err() {
        let source = DecimalArray::new(buffer![100i32], dec_10_2(), Validity::NonNullable);

        let message = cast(source.as_ref(), &DType::Utf8(Nullability::NonNullable))
            .expect_err("casting decimal to utf8 is expected to fail")
            .to_string();
        assert!(message.contains("No compute kernel to cast"));
    }

    #[rstest]
    #[case(DecimalArray::new(buffer![100i32, 200, 300], DecimalDType::new(10, 2), Validity::NonNullable))]
    #[case(DecimalArray::new(buffer![10000i64, 20000, 30000], DecimalDType::new(18, 4), Validity::NonNullable))]
    #[case(DecimalArray::from_option_iter([Some(100i32), None, Some(300)], DecimalDType::new(10, 2)))]
    #[case(DecimalArray::new(buffer![42i32], DecimalDType::new(5, 1), Validity::NonNullable))]
    fn test_cast_decimal_conformance(#[case] input: DecimalArray) {
        test_cast_conformance(input.as_ref());
    }
}