use vortex_error::VortexResult;
use super::Dict;
use super::DictArray;
use crate::ArrayRef;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::dict::DictArrayExt;
use crate::arrays::dict::DictArraySlotsExt;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::scalar_fn::fns::cast::CastReduce;
impl CastReduce for Dict {
    /// Casts a dictionary array by casting its values to `dtype` and, where
    /// required, stripping nullability from its codes.
    ///
    /// Returns `Ok(None)` when `dtype` is non-nullable while the values are
    /// nullable and not all valid. Presumably this signals to the caller that
    /// this reduction does not apply, so another cast path can be used.
    ///
    /// # Errors
    ///
    /// Propagates any error from the validity check or from casting the values
    /// or the codes.
    fn cast(array: ArrayView<'_, Dict>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
        let target_nullable = dtype.is_nullable();

        // The values may hold nulls that no code points at (the
        // unreferenced-null test in this file builds such a dictionary).
        // Casting those values straight to a non-nullable type would
        // presumably fail, so decline instead of attempting it.
        if !target_nullable {
            let values = array.values();
            if values.dtype().is_nullable() && !values.all_valid()? {
                return Ok(None);
            }
        }

        let casted_values = array.values().cast(dtype.clone())?;

        // Codes only need touching when they are nullable and the target is
        // not; in every other combination they are reused unchanged.
        let casted_codes = match (array.codes().dtype().is_nullable(), target_nullable) {
            (true, false) => {
                let codes_dtype = array.codes().dtype().with_nullability(dtype.nullability());
                array.codes().cast(codes_dtype)?
            }
            _ => array.codes().clone(),
        };

        // SAFETY: the codes are either reused as-is or only have their
        // nullability changed, and the values are cast as a whole, so this
        // relies on casting leaving the code -> value indexing of the source
        // dictionary intact.
        // NOTE(review): confirm against the documented contract of
        // `DictArray::new_unchecked` / `set_all_values_referenced`.
        let rebuilt = unsafe {
            DictArray::new_unchecked(casted_codes, casted_values)
                .set_all_values_referenced(array.has_all_values_referenced())
        };

        Ok(Some(rebuilt.into_array()))
    }
}
#[cfg(test)]
mod tests {
use rstest::rstest;
use vortex_buffer::buffer;
use crate::IntoArray;
use crate::ToCanonical;
use crate::arrays::Dict;
use crate::arrays::PrimitiveArray;
use crate::arrays::dict::DictArraySlotsExt;
use crate::assert_arrays_eq;
use crate::builders::dict::dict_encode;
use crate::builtins::ArrayBuiltins;
use crate::compute::conformance::cast::test_cast_conformance;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::PType;
/// Casting a dictionary of `i32` to non-nullable `i64` yields the requested
/// dtype and the same logical values.
#[test]
fn test_cast_dict_to_wider_type() {
    let target = DType::Primitive(PType::I64, Nullability::NonNullable);

    let source = buffer![1i32, 2, 3, 2, 1].into_array();
    let encoded = dict_encode(&source).unwrap();
    let widened = encoded.into_array().cast(target.clone()).unwrap();

    assert_eq!(widened.dtype(), &target);

    let expected = PrimitiveArray::from_iter([1i64, 2, 3, 2, 1]);
    assert_arrays_eq!(widened.to_primitive(), expected);
}
/// Casting a dictionary built from nullable `i32` input to nullable `i64`
/// reports the requested dtype.
#[test]
fn test_cast_dict_nullable() {
    let target = DType::Primitive(PType::I64, Nullability::Nullable);

    let source =
        PrimitiveArray::from_option_iter([Some(10i32), None, Some(20), Some(10), None])
            .into_array();
    let encoded = dict_encode(&source).unwrap();

    let widened = encoded.into_array().cast(target.clone()).unwrap();
    assert_eq!(widened.dtype(), &target);
}
/// Round-trips an all-valid dictionary through non-nullable -> nullable ->
/// non-nullable casts, checking the nullability of codes and values at each
/// stage and that the decoded values are unchanged at the end.
#[test]
fn test_cast_dict_allvalid_to_nonnullable_and_back() {
    let non_nullable_i32 = DType::Primitive(PType::I32, Nullability::NonNullable);
    let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);

    let source = buffer![10i32, 20, 30, 40].into_array();
    let dict = dict_encode(&source).unwrap();

    // Freshly encoded: neither codes nor values are nullable.
    assert_eq!(dict.codes().dtype().nullability(), Nullability::NonNullable);
    assert_eq!(
        dict.values().dtype().nullability(),
        Nullability::NonNullable
    );

    // Step 1: cast to the non-nullable dtype it already has.
    let step_one = dict
        .clone()
        .into_array()
        .cast(non_nullable_i32.clone())
        .unwrap();
    assert_eq!(step_one.dtype(), &non_nullable_i32);
    let step_one_dict = step_one.as_::<Dict>();
    assert_eq!(
        step_one_dict.codes().dtype().nullability(),
        Nullability::NonNullable
    );
    assert_eq!(
        step_one_dict.values().dtype().nullability(),
        Nullability::NonNullable
    );

    // Step 2: widen to nullable. Only the values become nullable; the codes
    // stay non-nullable.
    let step_two = step_one.cast(nullable_i32.clone()).unwrap();
    assert_eq!(step_two.dtype(), &nullable_i32);
    let step_two_dict = step_two.as_::<Dict>();
    assert_eq!(
        step_two_dict.codes().dtype().nullability(),
        Nullability::NonNullable
    );
    assert_eq!(
        step_two_dict.values().dtype().nullability(),
        Nullability::Nullable
    );

    // Step 3: narrow back to non-nullable.
    let step_three = step_two.cast(non_nullable_i32.clone()).unwrap();
    assert_eq!(step_three.dtype(), &non_nullable_i32);
    let step_three_dict = step_three.as_::<Dict>();
    assert_eq!(
        step_three_dict.codes().dtype().nullability(),
        Nullability::NonNullable
    );
    assert_eq!(
        step_three_dict.values().dtype().nullability(),
        Nullability::NonNullable
    );

    // The decoded data must survive the round trip unchanged.
    let before = dict.as_array().to_primitive();
    let after = step_three_dict.array().to_primitive();
    assert_arrays_eq!(before, after);
}
/// Runs the shared cast conformance checks over dictionary-encoded `i32`,
/// `u32`, nullable `i32` and `f32` inputs.
#[rstest]
#[case(
    dict_encode(&buffer![1i32, 2, 3, 2, 1, 3].into_array())
        .unwrap()
        .into_array()
)]
#[case(
    dict_encode(&buffer![100u32, 200, 100, 300, 200].into_array())
        .unwrap()
        .into_array()
)]
#[case(
    dict_encode(
        &PrimitiveArray::from_option_iter([Some(1i32), None, Some(2), Some(1), None])
            .into_array()
    )
    .unwrap()
    .into_array()
)]
#[case(
    dict_encode(&buffer![1.5f32, 2.5, 1.5, 3.5].into_array())
        .unwrap()
        .into_array()
)]
fn test_cast_dict_conformance(#[case] dict_array: crate::ArrayRef) {
    test_cast_conformance(&dict_array);
}
/// A dictionary whose only null value is never referenced by any code can
/// still be cast to a non-nullable dtype.
#[test]
fn test_cast_dict_with_unreferenced_null_values_to_nonnullable() {
    use crate::arrays::DictArray;
    use crate::validity::Validity;

    // Values are [1.0, <null>, 3.0]: slot 1 is marked invalid.
    let validity = Validity::from(vortex_buffer::BitBuffer::from(vec![true, false, true]));
    let values = PrimitiveArray::new(buffer![1.0f64, 0.0f64, 3.0f64], validity).into_array();

    // Codes [0, 2, 0] only point at slots 0 and 2, never at the null slot.
    let codes = buffer![0u32, 2, 0].into_array();
    let dict = DictArray::try_new(codes, values).unwrap();
    assert_eq!(
        dict.dtype(),
        &DType::Primitive(PType::F64, Nullability::Nullable)
    );

    let target = DType::Primitive(PType::F64, Nullability::NonNullable);
    let result = dict.into_array().cast(target.clone());
    assert!(
        result.is_ok(),
        "cast to NonNullable should succeed for dict with only unreferenced null values"
    );

    let casted = result.unwrap();
    assert_eq!(casted.dtype(), &target);

    let expected = PrimitiveArray::from_iter([1.0f64, 3.0, 1.0]);
    assert_arrays_eq!(casted.to_primitive(), expected);
}
}