#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use vortex_error::VortexExpect;
    use vortex_error::VortexResult;
    use vortex_error::vortex_bail;

    use crate::dtype::DType;
    use crate::dtype::FieldDType;
    use crate::dtype::Nullability;
    use crate::dtype::PType;
    use crate::dtype::StructFields;
    use crate::dtype::extension::ExtDType;
    use crate::dtype::extension::ExtId;
    use crate::dtype::extension::ExtVTable;
    use crate::dtype::half::f16;
    use crate::scalar::PValue;
    use crate::scalar::Scalar;
    use crate::scalar::ScalarValue;

    /// Minimal test extension type registered under the id `"apples"`.
    ///
    /// Metadata is a `usize`, but (de)serialization is a stub: serializing always yields an
    /// empty byte vector and deserializing always yields `0`. Every dtype passes validation,
    /// and `unpack_native` always returns the empty string.
    #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
    struct Apples;

    impl ExtVTable for Apples {
        type Metadata = usize;
        type NativeValue<'a> = &'a str;

        fn id(&self) -> ExtId {
            ExtId::new_ref("apples")
        }

        fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {
            Ok(vec![])
        }

        fn deserialize_metadata(&self, _data: &[u8]) -> VortexResult<Self::Metadata> {
            Ok(0)
        }

        fn validate_dtype(_ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
            Ok(())
        }

        fn unpack_native<'a>(
            _ext_dtype: &'a ExtDType<Self>,
            _storage_value: &'a ScalarValue,
        ) -> VortexResult<Self::NativeValue<'a>> {
            Ok("")
        }
    }

    impl Apples {
        /// Builds an `apples` extension dtype whose storage is a non-nullable `u16`.
        ///
        /// The leading `0` is presumably the `usize` metadata value (TODO confirm against
        /// `ExtDType::try_new`). Construction errors are surfaced through `vortex_expect`.
        fn new() -> ExtDType<Apples> {
            ExtDType::try_new(0, DType::Primitive(PType::U16, Nullability::NonNullable))
                .vortex_expect("valid apples dtype")
        }
    }

    /// Checks `Scalar::cast` between the `apples` extension dtype and its `u16` storage
    /// dtype, in both directions and into nullable variants. Every such cast must succeed
    /// and produce exactly the requested dtype. Casting a value that does not fit the
    /// target extension's storage type must fail with a descriptive error.
    #[test]
    fn cast_to_from_extension_types() {
        let apples = Apples::new();
        let ext_dtype = DType::Extension(apples.clone().erased());
        let ext_scalar = Scalar::new(
            ext_dtype.clone(),
            Some(ScalarValue::Primitive(PValue::U16(1000))),
        );
        let storage_scalar = Scalar::new(
            DType::clone(apples.storage_dtype()),
            Some(ScalarValue::Primitive(PValue::U16(1000))),
        );

        // extension -> same extension
        let expected_dtype = &ext_dtype;
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // extension -> nullable extension
        let expected_dtype = &ext_dtype.as_nullable();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // extension -> storage dtype
        let expected_dtype = apples.storage_dtype();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // extension -> nullable storage dtype
        let expected_dtype = &apples.storage_dtype().as_nullable();
        let actual = ext_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // storage -> extension
        let expected_dtype = &ext_dtype;
        let actual = storage_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // storage -> nullable extension
        let expected_dtype = &ext_dtype.as_nullable();
        let actual = storage_scalar.cast(expected_dtype).unwrap();
        assert_eq!(actual.dtype(), expected_dtype);

        // storage -> extension backed by u8: 1000 exceeds u8::MAX (255), so this must fail.
        let apples_u8 =
            ExtDType::<Apples>::try_new(0, DType::Primitive(PType::U8, Nullability::NonNullable))
                .unwrap();
        let expected_dtype = &DType::Extension(apples_u8.erased());
        let result = storage_scalar.cast(expected_dtype);
        assert!(
            result
                .as_ref()
                .is_err_and(|err| { err.to_string().contains("Cannot cast 1000u16 to u8") }),
            "{result:?}"
        );
    }

    /// A struct scalar whose `f16` field value is supplied as its raw bit pattern in a
    /// `U64` must read back as `PValue::F16`. The `u32` and `f32` fields must read back
    /// unchanged.
    #[test]
    fn test_struct_field_coercion() {
        let f16_value = f16::from_f32(0.42);
        let f32_value = std::f32::consts::PI;
        let struct_dtype = DType::Struct(
            StructFields::from_iter([
                (
                    "a",
                    FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable)),
                ),
                (
                    "b",
                    FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable)),
                ),
                (
                    "c",
                    FieldDType::from(DType::Primitive(PType::F32, Nullability::NonNullable)),
                ),
            ]),
            Nullability::NonNullable,
        );
        let field_values = vec![
            Some(ScalarValue::Primitive(PValue::U32(42))),
            // The f16 field is deliberately given as its bit pattern widened to u64.
            Some(ScalarValue::Primitive(PValue::U64(
                f16_value.to_bits() as u64
            ))),
            Some(ScalarValue::Primitive(PValue::F32(f32_value))),
        ];
        let scalar = Scalar::new(struct_dtype, Some(ScalarValue::List(field_values)));
        let struct_scalar = scalar.as_struct();
        let fields: Vec<_> = (0..3)
            .map(|i| struct_scalar.field_by_idx(i).unwrap())
            .collect();

        assert_eq!(fields[0].as_primitive().pvalue().unwrap(), PValue::U32(42));
        assert_eq!(
            fields[1].as_primitive().pvalue().unwrap(),
            PValue::F16(f16_value)
        );
        assert_eq!(
            fields[2].as_primitive().pvalue().unwrap(),
            PValue::F32(f32_value)
        );
    }

    /// A value whose `PValue` variant already matches the dtype (`I32`/`I32`) round-trips
    /// unchanged.
    #[test]
    fn test_fake_coercion_for_matching_type() {
        let i32_value = 42i32;
        let scalar = Scalar::new(
            DType::Primitive(PType::I32, Nullability::NonNullable),
            Some(ScalarValue::Primitive(PValue::I32(i32_value))),
        );
        assert_eq!(
            scalar.as_primitive().pvalue().unwrap(),
            PValue::I32(i32_value)
        );
    }

    /// List elements of an `f16` list, supplied as raw `f16` bit patterns in `U64`s, must
    /// read back as `PValue::F16`, one per input element.
    #[test]
    fn test_list_element_coercion() {
        let f16_value1 = f16::from_f32(1.0);
        let f16_value2 = f16::from_f32(2.0);
        let list_dtype = DType::List(
            Arc::new(DType::Primitive(PType::F16, Nullability::NonNullable)),
            Nullability::NonNullable,
        );
        let elements = vec![
            Some(ScalarValue::Primitive(PValue::U64(
                f16_value1.to_bits() as u64
            ))),
            Some(ScalarValue::Primitive(PValue::U64(
                f16_value2.to_bits() as u64
            ))),
        ];
        let scalar = Scalar::new(list_dtype, Some(ScalarValue::List(elements)));
        let list_scalar = scalar.as_list();
        let elements = list_scalar.elements().unwrap();

        let expected = [f16_value1, f16_value2];
        // Check the count first so a missing or extra element is reported clearly instead
        // of surfacing as an index panic or being silently ignored by `zip`.
        assert_eq!(elements.len(), expected.len());
        for (element, want) in elements.iter().zip(expected.iter()) {
            assert_eq!(element.as_primitive().pvalue().unwrap(), PValue::F16(*want));
        }
    }

    /// Supplying `u64::MAX` for an `f16` primitive dtype is expected to panic.
    ///
    /// NOTE(review): `#[should_panic]` carries no `expected = "..."`, so a panic from
    /// either `Scalar::new` or `as_primitive` (or anything else in this body) satisfies the
    /// test. Consider pinning the expected message.
    #[test]
    #[should_panic]
    fn test_coercion_with_overflow_protection() {
        let large_u64 = u64::MAX;
        let scalar = Scalar::new(
            DType::Primitive(PType::F16, Nullability::NonNullable),
            Some(ScalarValue::Primitive(PValue::U64(large_u64))),
        );
        let _ = scalar.as_primitive();
    }

    /// An extension scalar over `f16` storage, supplied as raw `f16` bits in a `U64`, must
    /// expose an `F16` value through its storage scalar.
    #[test]
    fn test_extension_dtype_coercion() {
        /// Test-only extension over `f16` storage. Metadata (de)serialization is
        /// intentionally unimplemented and bails.
        #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
        struct F16Ext;

        impl ExtVTable for F16Ext {
            type Metadata = usize;
            type NativeValue<'a> = &'a str;

            fn id(&self) -> ExtId {
                ExtId::new_ref("f16_ext")
            }

            fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {
                vortex_bail!("not implemented")
            }

            fn deserialize_metadata(&self, _data: &[u8]) -> VortexResult<Self::Metadata> {
                vortex_bail!("not implemented")
            }

            fn validate_dtype(_ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
                Ok(())
            }

            fn unpack_native<'a>(
                _ext_dtype: &'a ExtDType<Self>,
                _storage_value: &'a ScalarValue,
            ) -> VortexResult<Self::NativeValue<'a>> {
                Ok("")
            }
        }

        let storage_dtype = DType::Primitive(PType::F16, Nullability::NonNullable);
        let ext_dtype = ExtDType::<F16Ext>::try_new(0, storage_dtype).unwrap();
        let f16_value = f16::from_f32(0.42);
        let u64_bits = f16_value.to_bits() as u64;
        let scalar = Scalar::new(
            DType::Extension(ext_dtype.erased()),
            Some(ScalarValue::Primitive(PValue::U64(u64_bits))),
        );

        assert_eq!(
            scalar
                .as_extension()
                .to_storage_scalar()
                .as_primitive()
                .pvalue()
                .unwrap(),
            PValue::F16(f16_value)
        );
    }

    /// An extension scalar over struct storage `{id: u32, value: f16}`, with the `f16`
    /// field given as raw bits in a `U64`, must expose correctly typed struct fields
    /// through its storage scalar.
    #[test]
    fn test_extension_dtype_nested_struct_coercion() {
        /// Test-only extension over struct storage. Metadata (de)serialization is
        /// intentionally unimplemented and bails.
        #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
        struct StructExt;

        impl ExtVTable for StructExt {
            type Metadata = usize;
            type NativeValue<'a> = &'a str;

            fn id(&self) -> ExtId {
                ExtId::new_ref("struct_ext")
            }

            fn serialize_metadata(&self, _metadata: &Self::Metadata) -> VortexResult<Vec<u8>> {
                vortex_bail!("not implemented")
            }

            fn deserialize_metadata(&self, _data: &[u8]) -> VortexResult<Self::Metadata> {
                vortex_bail!("not implemented")
            }

            fn validate_dtype(_ext_dtype: &ExtDType<Self>) -> VortexResult<()> {
                Ok(())
            }

            fn unpack_native<'a>(
                _ext_dtype: &'a ExtDType<Self>,
                _storage_value: &'a ScalarValue,
            ) -> VortexResult<Self::NativeValue<'a>> {
                Ok("")
            }
        }

        let struct_dtype = DType::Struct(
            StructFields::from_iter([
                (
                    "id",
                    FieldDType::from(DType::Primitive(PType::U32, Nullability::NonNullable)),
                ),
                (
                    "value",
                    FieldDType::from(DType::Primitive(PType::F16, Nullability::NonNullable)),
                ),
            ]),
            Nullability::NonNullable,
        );
        let ext_dtype = ExtDType::<StructExt>::try_new(0, struct_dtype).unwrap();
        let f16_value = f16::from_f32(1.5);
        let field_values = vec![
            Some(ScalarValue::Primitive(PValue::U32(123))),
            Some(ScalarValue::Primitive(PValue::U64(
                f16_value.to_bits() as u64
            ))),
        ];
        let scalar = Scalar::new(
            DType::Extension(ext_dtype.erased()),
            Some(ScalarValue::List(field_values)),
        );

        let fields = scalar
            .as_extension()
            .to_storage_scalar()
            .as_struct()
            .fields_iter()
            .vortex_expect("non null")
            .collect::<Vec<_>>();

        // One scalar per declared struct field, in declaration order.
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].as_primitive().pvalue().unwrap(), PValue::U32(123));
        assert_eq!(
            fields[1].as_primitive().pvalue().unwrap(),
            PValue::F16(f16_value)
        );
    }
}