use vyre_spec::data_type::DataType;
/// Every scalar `DataType` variant must compare unequal to every other one.
///
/// All scalar variants referenced in this file are listed once, and each
/// unordered pair is checked. This covers pairs the old adjacent-only
/// comparisons missed (e.g. `U8` vs `I8`, `F16` vs `BF16`) and the `BF16` /
/// `F8E4M3` / `F8E5M2` variants, which were previously never compared at all.
#[test]
fn all_scalar_variants_are_distinct() {
    let scalars = [
        DataType::U8,
        DataType::U16,
        DataType::U32,
        DataType::U64,
        DataType::I8,
        DataType::I16,
        DataType::I32,
        DataType::I64,
        DataType::F16,
        DataType::BF16,
        DataType::F32,
        DataType::F64,
        DataType::F8E4M3,
        DataType::F8E5M2,
        DataType::Bool,
    ];
    // Compare each variant only against those after it, so every unordered
    // pair is checked exactly once and nothing is compared with itself.
    for (i, a) in scalars.iter().enumerate() {
        for b in &scalars[i + 1..] {
            assert_ne!(a, b, "{:?} and {:?} must be distinct variants", a, b);
        }
    }
}
/// `size_bytes()` reports `Some(4)` for `DataType::U32`.
#[test]
fn u32_size_is_4_bytes() {
    let reported = DataType::U32.size_bytes();
    assert_eq!(reported, Some(4));
}
/// `size_bytes()` reports `Some(4)` for `DataType::F32`.
#[test]
fn f32_size_is_4_bytes() {
    let dtype = DataType::F32;
    assert_eq!(dtype.size_bytes(), Some(4));
}
/// `size_bytes()` reports `Some(8)` for `DataType::F64`.
#[test]
fn f64_size_is_8_bytes() {
    let reported = DataType::F64.size_bytes();
    assert_eq!(reported, Some(8));
}
/// `DataType::F32` is classified as a member of the float family.
#[test]
fn f32_is_float_family() {
    let is_float = DataType::F32.is_float_family();
    assert!(is_float);
}
/// `DataType::F16` is classified as a member of the float family.
#[test]
fn f16_is_float_family() {
    let dtype = DataType::F16;
    assert!(dtype.is_float_family());
}
/// `DataType::BF16` is classified as a member of the float family.
#[test]
fn bf16_is_float_family() {
    let is_float = DataType::BF16.is_float_family();
    assert!(is_float);
}
/// `DataType::U32` must not be classified as a float-family type.
#[test]
fn u32_is_not_float_family() {
    let is_float = DataType::U32.is_float_family();
    assert!(!is_float);
}
/// `DataType::Bool` must not be classified as a float-family type.
#[test]
fn bool_is_not_float_family() {
    let dtype = DataType::Bool;
    assert!(!dtype.is_float_family());
}
/// `element_size()` yields `None` for the plain scalar `DataType::U32`.
#[test]
fn u32_element_size_is_none() {
    let element = DataType::U32.element_size();
    assert_eq!(element, None);
}
/// `size_bytes()` reports `Some(8)` for `DataType::Vec2U32`.
#[test]
fn vec2_u32_size_is_8() {
    let dtype = DataType::Vec2U32;
    let reported = dtype.size_bytes();
    assert_eq!(reported, Some(8));
}
/// `size_bytes()` reports `Some(16)` for `DataType::Vec4U32`.
#[test]
fn vec4_u32_size_is_16() {
    let dtype = DataType::Vec4U32;
    let reported = dtype.size_bytes();
    assert_eq!(reported, Some(16));
}
/// `min_bytes()` returns `4` for `DataType::U32`.
#[test]
fn min_bytes_for_u32_is_4() {
    let minimum = DataType::U32.min_bytes();
    assert_eq!(minimum, 4);
}
/// `max_bytes()` returns `Some(4)` for the scalar `DataType::U32`.
#[test]
fn max_bytes_for_scalar_is_some() {
    let maximum = DataType::U32.max_bytes();
    assert_eq!(maximum, Some(4));
}
/// `size_bytes()` reports `Some(8)` for `DataType::I64`.
#[test]
fn i64_size_is_8() {
    let dtype = DataType::I64;
    assert_eq!(dtype.size_bytes(), Some(8));
}
/// `DataType::F8E4M3` is classified as a member of the float family.
#[test]
fn f8e4m3_is_float_family() {
    let is_float = DataType::F8E4M3.is_float_family();
    assert!(is_float);
}
/// `DataType::F8E5M2` is classified as a member of the float family.
#[test]
fn f8e5m2_is_float_family() {
    let dtype = DataType::F8E5M2;
    assert!(dtype.is_float_family());
}
/// `element_size()` yields `None` for `DataType::Bytes`.
#[test]
fn bytes_has_no_element_size() {
    let element = DataType::Bytes.element_size();
    assert_eq!(element, None);
}
/// `size_bytes()` reports `Some(1)` for `DataType::Bytes`.
#[test]
fn bytes_size_is_1() {
    let reported = DataType::Bytes.size_bytes();
    assert_eq!(reported, Some(1));
}
/// `size_bytes()` yields `None` for `DataType::Tensor`.
#[test]
fn tensor_has_no_size_bytes() {
    let dtype = DataType::Tensor;
    let reported = dtype.size_bytes();
    assert_eq!(reported, None);
}
/// The `element_size` used to build a `DataType::Array` is returned unchanged
/// by `element_size()`.
#[test]
fn array_size_roundtrips() {
    let array = DataType::Array { element_size: 16 };
    let reported = array.element_size();
    assert_eq!(reported, Some(16));
}