use super::super::*;
use crate::spec::types::DataType;
use crate::spec::value::TensorValue;
/// Exactly four little-endian bytes decode to a `U32` value.
#[test]
fn from_element_bytes_accepts_exact_length_u32() {
    let wire = 1u32.to_le_bytes();
    let decoded = Value::from_element_bytes(DataType::U32, &wire);
    assert_eq!(decoded, Ok(Value::U32(1)));
}
/// A two-byte slice is too short for `U32`; the error carries the expected and actual lengths.
#[test]
fn from_element_bytes_rejects_short_u32() {
    let wire = 1u32.to_le_bytes();
    let outcome = Value::from_element_bytes(DataType::U32, &wire[..2]);
    assert_eq!(outcome, Err(ValueError::TruncatedInput { expected: 4, actual: 2 }));
}
/// Five bytes are rejected for the four-byte `U32` as well.
///
/// Oversized input is reported with the same `TruncatedInput` variant as undersized input,
/// with `actual` exceeding `expected`.
#[test]
fn from_element_bytes_rejects_long_u32() {
    let mut wire = 1u32.to_le_bytes().to_vec();
    wire.push(0x00); // one byte past the U32 width
    let outcome = Value::from_element_bytes(DataType::U32, &wire);
    assert_eq!(outcome, Err(ValueError::TruncatedInput { expected: 4, actual: 5 }));
}
/// Four `0xFF` bytes (two's-complement `-1`) decode to `I32(-1)`.
#[test]
fn from_element_bytes_accepts_exact_length_i32() {
    let wire = (-1i32).to_le_bytes();
    assert_eq!(Value::from_element_bytes(DataType::I32, &wire), Ok(Value::I32(-1)));
}
/// A single byte cannot satisfy the four-byte `I32` width.
#[test]
fn from_element_bytes_rejects_short_i32() {
    let wire = [0xFFu8];
    let expected = ValueError::TruncatedInput { expected: 4, actual: 1 };
    assert_eq!(Value::from_element_bytes(DataType::I32, &wire), Err(expected));
}
/// Eight little-endian bytes decode to a `U64` value.
#[test]
fn from_element_bytes_accepts_exact_length_u64() {
    let wire = 1u64.to_le_bytes();
    let decoded = Value::from_element_bytes(DataType::U64, &wire);
    assert_eq!(decoded, Ok(Value::U64(1)));
}
/// Four bytes are only half of the eight-byte `U64` width.
#[test]
fn from_element_bytes_rejects_short_u64() {
    let wire = 1u64.to_le_bytes();
    let outcome = Value::from_element_bytes(DataType::U64, &wire[..4]);
    assert_eq!(outcome, Err(ValueError::TruncatedInput { expected: 8, actual: 4 }));
}
/// A four-byte payload holding `1` decodes to `Bool(true)`.
#[test]
fn from_element_bytes_accepts_exact_length_bool() {
    let wire = 1u32.to_le_bytes();
    let decoded = Value::from_element_bytes(DataType::Bool, &wire);
    assert_eq!(decoded, Ok(Value::Bool(true)));
}
/// An empty slice is rejected for `Bool`, which expects four bytes.
#[test]
fn from_element_bytes_rejects_short_bool() {
    let nothing: [u8; 0] = [];
    let outcome = Value::from_element_bytes(DataType::Bool, &nothing);
    assert_eq!(outcome, Err(ValueError::TruncatedInput { expected: 4, actual: 0 }));
}
/// `Bytes` has no fixed width: both a two-byte and an empty payload come back unchanged.
#[test]
fn from_element_bytes_accepts_any_length_bytes() {
    for payload in [vec![0xABu8, 0xCD], Vec::new()] {
        let decoded = Value::from_element_bytes(DataType::Bytes, &payload);
        assert_eq!(decoded, Ok(Value::Bytes(payload)));
    }
}
/// Eight bytes (two `u32` lanes, here 1 and 2) are accepted for `Vec2U32` and returned as raw
/// `Bytes`.
#[test]
fn from_element_bytes_accepts_exact_length_vec2u32() {
    let wire: Vec<u8> = [1u32, 2].iter().flat_map(|lane| lane.to_le_bytes()).collect();
    let decoded = Value::from_element_bytes(DataType::Vec2U32, &wire);
    assert_eq!(decoded, Ok(Value::Bytes(wire)));
}
/// Four bytes are too short for `Vec2U32`, which needs eight.
#[test]
fn from_element_bytes_rejects_short_vec2u32() {
    let one_lane = 1u32.to_le_bytes();
    let outcome = Value::from_element_bytes(DataType::Vec2U32, &one_lane);
    assert_eq!(outcome, Err(ValueError::TruncatedInput { expected: 8, actual: 4 }));
}
/// Sixteen bytes (four `u32` lanes, here 1..=4) are accepted for `Vec4U32` and returned as raw
/// `Bytes`.
#[test]
fn from_element_bytes_accepts_exact_length_vec4u32() {
    let wire: Vec<u8> = [1u32, 2, 3, 4].iter().flat_map(|lane| lane.to_le_bytes()).collect();
    let decoded = Value::from_element_bytes(DataType::Vec4U32, &wire);
    assert_eq!(decoded, Ok(Value::Bytes(wire)));
}
/// Four bytes fall well short of the sixteen that `Vec4U32` requires.
#[test]
fn from_element_bytes_rejects_short_vec4u32() {
    let one_lane = 1u32.to_le_bytes();
    let outcome = Value::from_element_bytes(DataType::Vec4U32, &one_lane);
    assert_eq!(outcome, Err(ValueError::TruncatedInput { expected: 16, actual: 4 }));
}
/// Tag `0xFF` is reported as an unknown tag rather than with a length error.
#[test]
fn from_element_type_tag_unknown_tag_still_errors_before_length_check() {
    let payload = [0u8, 1, 2, 3];
    let outcome = Value::from_element_type_tag(0xFF, &payload);
    assert_eq!(outcome, Err(ValueError::UnknownElementTypeTag));
}
/// Each floating-point data type (`F16`, `BF16`, `F32`, `F64`) has `Float(0.0)` as its zero.
#[test]
fn zero_for_float_types_returns_zero_float() {
    let float_types = [DataType::F16, DataType::BF16, DataType::F32, DataType::F64];
    for data_type in float_types {
        assert_eq!(Value::zero_for(data_type), Value::Float(0.0));
    }
}
/// The zero value of `Tensor` is a 1x1 `F32` tensor whose four data bytes are all zero.
#[test]
fn zero_for_tensor_returns_1x1_zero_tensor() {
    let expected = TensorValue {
        shape: vec![1, 1],
        element_type: DataType::F32,
        data: vec![0; 4],
    };
    assert_eq!(Value::zero_for(DataType::Tensor), Value::Tensor(expected));
}
/// An array type with element size 0 has an empty array as its zero value.
#[test]
fn zero_for_array_empty_when_element_size_zero() {
    let zero = Value::zero_for(DataType::Array { element_size: 0 });
    assert_eq!(zero, Value::Array(Vec::new()));
}
/// With element size 4, the zero array holds exactly one element of four zero bytes.
#[test]
fn zero_for_array_with_one_element_when_nonzero_size() {
    let zero = Value::zero_for(DataType::Array { element_size: 4 });
    let only_element = Value::Bytes(vec![0; 4]);
    assert_eq!(zero, Value::Array(vec![only_element]));
}
/// The little-endian encoding of `1.5f32` decodes to `Float(1.5)`.
#[test]
fn from_element_bytes_float32_roundtrips() {
    let decoded = Value::from_element_bytes(DataType::F32, &1.5f32.to_le_bytes());
    assert_eq!(decoded, Ok(Value::Float(1.5)));
}
/// The little-endian encoding of `2.5f64` decodes to `Float(2.5)`.
#[test]
fn from_element_bytes_float64_roundtrips() {
    let decoded = Value::from_element_bytes(DataType::F64, &2.5f64.to_le_bytes());
    assert_eq!(decoded, Ok(Value::Float(2.5)));
}
/// The float element-type tags decode to `Value::Float`.
///
/// Tags 8 and 9 take two-byte payloads, tag 10 a four-byte payload and tag 11 an eight-byte
/// payload. All-zero bytes decode to `0.0` under every float width, so those checks alone cannot
/// catch a wrong tag-to-decoder mapping. The non-zero checks below pin tags 10 and 11 to
/// width-specific encodings.
#[test]
fn from_element_type_tag_maps_new_float_tags() {
    assert_eq!(
        Value::from_element_type_tag(8, &[0, 0]).unwrap(),
        Value::Float(0.0)
    );
    assert_eq!(
        Value::from_element_type_tag(9, &[0, 0]).unwrap(),
        Value::Float(0.0)
    );
    assert_eq!(
        Value::from_element_type_tag(10, &[0, 0, 0, 0]).unwrap(),
        Value::Float(0.0)
    );
    assert_eq!(
        Value::from_element_type_tag(11, &[0, 0, 0, 0, 0, 0, 0, 0]).unwrap(),
        Value::Float(0.0)
    );

    // Assumes tag 10 is F32 and tag 11 is F64, matching their 4- and 8-byte payload widths.
    // 1.5 and 2.5 are exactly representable in both widths, so the comparison is exact.
    let f32_wire = 1.5f32.to_le_bytes();
    assert_eq!(
        Value::from_element_type_tag(10, &f32_wire).unwrap(),
        Value::Float(1.5)
    );
    let f64_wire = 2.5f64.to_le_bytes();
    assert_eq!(
        Value::from_element_type_tag(11, &f64_wire).unwrap(),
        Value::Float(2.5)
    );
    // NOTE(review): tags 8 and 9 are presumably F16 and BF16. A non-zero payload, e.g. 1.0 as
    // F16 = [0x00, 0x3C] and as BF16 = [0x80, 0x3F], would also distinguish them. Add it once the
    // half-precision decoding is confirmed.
}
/// An array decoded through tag 13 takes its element size from the payload header.
///
/// The payload built here is an element size of `4` and an element count of `1`, each as an
/// 8-byte little-endian integer, followed by the single element's bytes.
#[test]
fn from_element_type_tag_array_uses_wire_element_size() {
    let mut wire: Vec<u8> = Vec::with_capacity(20);
    wire.extend_from_slice(&4u64.to_le_bytes()); // element size
    wire.extend_from_slice(&1u64.to_le_bytes()); // element count
    wire.extend_from_slice(&[0xAB, 0xCD, 0xEF, 0x00]); // the one element
    let decoded = Value::from_element_type_tag(13, &wire).unwrap();
    let element = Value::Bytes(vec![0xAB, 0xCD, 0xEF, 0x00]);
    assert_eq!(decoded, Value::Array(vec![element]));
}
/// Tag 12 decodes a tensor from a header followed by element data.
///
/// NOTE(review): the fixture layout described here is read off the bytes and the expected
/// result, not from the parser; confirm it against `from_element_type_tag`. Presumably the bytes
/// are a leading rank byte (`2`), one 8-byte little-endian extent per dimension (`2`, then `3`),
/// an element-type tag byte (`10`), then zero-filled element data. The expected `data` is 24
/// zero bytes for the 2x3 `F32` tensor.
#[test]
fn from_element_type_tag_tensor_parses_header() {
// Keep this literal byte-exact: its length must match what the tensor decoder consumes.
let bytes = [
2u8, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
];
// Tag 12 yields a `Value::Tensor`, as the assertion below checks.
let v = Value::from_element_type_tag(12, &bytes).unwrap();
assert_eq!(
v,
Value::Tensor(TensorValue {
shape: vec![2, 3],
element_type: DataType::F32,
data: vec![0; 24],
})
);
}
/// An `Array { element_size: 4 }` payload with a count prefix of 2 yields two four-byte elements.
///
/// The payload built here is the element size (`4`) and the element count (`2`), each as an
/// 8-byte little-endian integer, followed by the elements' bytes.
#[test]
fn from_element_bytes_array_parses_length_prefix() {
    let mut wire: Vec<u8> = Vec::with_capacity(24);
    wire.extend_from_slice(&4u64.to_le_bytes()); // element size
    wire.extend_from_slice(&2u64.to_le_bytes()); // element count
    for lane in [1u32, 2] {
        wire.extend_from_slice(&lane.to_le_bytes());
    }
    let decoded = Value::from_element_bytes(DataType::Array { element_size: 4 }, &wire).unwrap();
    let expected = Value::Array(vec![
        Value::Bytes(vec![0x01, 0x00, 0x00, 0x00]),
        Value::Bytes(vec![0x02, 0x00, 0x00, 0x00]),
    ]);
    assert_eq!(decoded, expected);
}
/// `Value::Float` equality treats NaN as equal to NaN, yet NaN still differs from `0.0`.
#[test]
fn float_equality_uses_bit_pattern() {
    let nan = Value::Float(f64::NAN);
    let zero = Value::Float(0.0);
    assert_eq!(nan, Value::Float(f64::NAN));
    assert_ne!(nan, zero);
}