vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
use super::super::*;
use crate::spec::types::DataType;
use crate::spec::value::TensorValue;

/// Decoding exactly four bytes as `U32` succeeds: `01 00 00 00` yields `Value::U32(1)`.
#[test]
fn from_element_bytes_accepts_exact_length_u32() {
    let wire: &[u8] = &[0x01, 0x00, 0x00, 0x00];
    let decoded = Value::from_element_bytes(DataType::U32, wire).unwrap();
    assert_eq!(decoded, Value::U32(1));
}

/// A two-byte input for `U32` is rejected with `TruncatedInput { expected: 4, actual: 2 }`.
#[test]
fn from_element_bytes_rejects_short_u32() {
    let wire: &[u8] = &[0x01, 0x00];
    let want = ValueError::TruncatedInput {
        expected: 4,
        actual: 2,
    };
    assert_eq!(
        Value::from_element_bytes(DataType::U32, wire).unwrap_err(),
        want
    );
}

/// A five-byte input for `U32` is rejected as well; the expected error is still
/// `TruncatedInput`, carrying `expected: 4` and `actual: 5`.
#[test]
fn from_element_bytes_rejects_long_u32() {
    let wire: &[u8] = &[0x01, 0x00, 0x00, 0x00, 0x00];
    let want = ValueError::TruncatedInput {
        expected: 4,
        actual: 5,
    };
    assert_eq!(
        Value::from_element_bytes(DataType::U32, wire).unwrap_err(),
        want
    );
}

/// Four `0xFF` bytes decode as `I32` to `Value::I32(-1)`.
#[test]
fn from_element_bytes_accepts_exact_length_i32() {
    let wire: &[u8] = &[0xFF; 4];
    let decoded = Value::from_element_bytes(DataType::I32, wire).unwrap();
    assert_eq!(decoded, Value::I32(-1));
}

/// A single-byte input for `I32` is rejected with `TruncatedInput { expected: 4, actual: 1 }`.
#[test]
fn from_element_bytes_rejects_short_i32() {
    let wire: &[u8] = &[0xFF];
    let want = ValueError::TruncatedInput {
        expected: 4,
        actual: 1,
    };
    assert_eq!(
        Value::from_element_bytes(DataType::I32, wire).unwrap_err(),
        want
    );
}

/// Eight bytes holding the little-endian encoding of `1` decode as `U64` to `Value::U64(1)`.
#[test]
fn from_element_bytes_accepts_exact_length_u64() {
    // Same eight bytes as the literal `01 00 00 00 00 00 00 00`.
    let wire = 1u64.to_le_bytes();
    let decoded = Value::from_element_bytes(DataType::U64, &wire).unwrap();
    assert_eq!(decoded, Value::U64(1));
}

/// A four-byte input for `U64` is rejected with `TruncatedInput { expected: 8, actual: 4 }`.
#[test]
fn from_element_bytes_rejects_short_u64() {
    let wire: &[u8] = &[0x01, 0x00, 0x00, 0x00];
    let want = ValueError::TruncatedInput {
        expected: 8,
        actual: 4,
    };
    assert_eq!(
        Value::from_element_bytes(DataType::U64, wire).unwrap_err(),
        want
    );
}

/// The four bytes `01 00 00 00` decode as `Bool` to `Value::Bool(true)`.
#[test]
fn from_element_bytes_accepts_exact_length_bool() {
    let wire: &[u8] = &[0x01, 0x00, 0x00, 0x00];
    let decoded = Value::from_element_bytes(DataType::Bool, wire).unwrap();
    assert_eq!(decoded, Value::Bool(true));
}

/// An empty input for `Bool` is rejected with `TruncatedInput { expected: 4, actual: 0 }`.
#[test]
fn from_element_bytes_rejects_short_bool() {
    let wire: &[u8] = &[];
    let want = ValueError::TruncatedInput {
        expected: 4,
        actual: 0,
    };
    assert_eq!(
        Value::from_element_bytes(DataType::Bool, wire).unwrap_err(),
        want
    );
}

/// `Bytes` inputs are returned verbatim as `Value::Bytes`, both for a two-byte
/// payload and for an empty one.
#[test]
fn from_element_bytes_accepts_any_length_bytes() {
    // Checked in the same order as before: non-empty payload first, then empty.
    let cases: [&[u8]; 2] = [&[0xAB, 0xCD], &[]];
    for &payload in &cases {
        let decoded = Value::from_element_bytes(DataType::Bytes, payload).unwrap();
        assert_eq!(decoded, Value::Bytes(payload.to_vec()));
    }
}

/// An eight-byte input for `Vec2U32` is accepted and comes back unchanged inside `Value::Bytes`.
#[test]
fn from_element_bytes_accepts_exact_length_vec2u32() {
    let wire: &[u8] = &[0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00];
    let decoded = Value::from_element_bytes(DataType::Vec2U32, wire).unwrap();
    assert_eq!(decoded, Value::Bytes(wire.to_vec()));
}

/// A four-byte input for `Vec2U32` is rejected with `TruncatedInput { expected: 8, actual: 4 }`.
#[test]
fn from_element_bytes_rejects_short_vec2u32() {
    let wire: &[u8] = &[0x01, 0x00, 0x00, 0x00];
    let want = ValueError::TruncatedInput {
        expected: 8,
        actual: 4,
    };
    assert_eq!(
        Value::from_element_bytes(DataType::Vec2U32, wire).unwrap_err(),
        want
    );
}

/// A sixteen-byte input for `Vec4U32` is accepted and comes back unchanged inside `Value::Bytes`.
#[test]
fn from_element_bytes_accepts_exact_length_vec4u32() {
    let wire: [u8; 16] = [
        0x01, 0x00, 0x00, 0x00, // first 4-byte group
        0x02, 0x00, 0x00, 0x00, // second 4-byte group
        0x03, 0x00, 0x00, 0x00, // third 4-byte group
        0x04, 0x00, 0x00, 0x00, // fourth 4-byte group
    ];
    let decoded = Value::from_element_bytes(DataType::Vec4U32, &wire).unwrap();
    assert_eq!(decoded, Value::Bytes(wire.to_vec()));
}

/// A four-byte input for `Vec4U32` is rejected with `TruncatedInput { expected: 16, actual: 4 }`.
#[test]
fn from_element_bytes_rejects_short_vec4u32() {
    let wire: &[u8] = &[0x01, 0x00, 0x00, 0x00];
    let want = ValueError::TruncatedInput {
        expected: 16,
        actual: 4,
    };
    assert_eq!(
        Value::from_element_bytes(DataType::Vec4U32, wire).unwrap_err(),
        want
    );
}

/// Tag `0xFF` with a four-byte payload must produce `ValueError::UnknownElementTypeTag`.
///
/// NOTE(review): the test name implies the tag is validated before any length
/// check; this test only pins the resulting error — confirm against the decoder.
#[test]
fn from_element_type_tag_unknown_tag_still_errors_before_length_check() {
    let outcome = Value::from_element_type_tag(0xFF, &[0, 1, 2, 3]);
    assert_eq!(outcome.unwrap_err(), ValueError::UnknownElementTypeTag);
}

/// `zero_for` returns `Value::Float(0.0)` for each of `F16`, `BF16`, `F32` and `F64`.
#[test]
fn zero_for_float_types_returns_zero_float() {
    // `assert_eq!` only borrows its operands, so one expected value serves all four checks.
    let zero = Value::Float(0.0);
    assert_eq!(Value::zero_for(DataType::F16), zero);
    assert_eq!(Value::zero_for(DataType::BF16), zero);
    assert_eq!(Value::zero_for(DataType::F32), zero);
    assert_eq!(Value::zero_for(DataType::F64), zero);
}

/// `zero_for(DataType::Tensor)` yields a tensor with shape `[1, 1]`, element type
/// `F32`, and four zero data bytes.
#[test]
fn zero_for_tensor_returns_1x1_zero_tensor() {
    let want = TensorValue {
        shape: vec![1, 1],
        element_type: DataType::F32,
        data: vec![0; 4],
    };
    assert_eq!(Value::zero_for(DataType::Tensor), Value::Tensor(want));
}

/// With `element_size: 0`, `zero_for` on an array type returns an empty `Value::Array`.
#[test]
fn zero_for_array_empty_when_element_size_zero() {
    let dtype = DataType::Array { element_size: 0 };
    assert_eq!(Value::zero_for(dtype), Value::Array(Vec::new()));
}

/// With `element_size: 4`, `zero_for` on an array type returns a single element:
/// `Value::Bytes` holding four zero bytes.
#[test]
fn zero_for_array_with_one_element_when_nonzero_size() {
    let element = Value::Bytes(vec![0; 4]);
    let zeroed = Value::zero_for(DataType::Array { element_size: 4 });
    assert_eq!(zeroed, Value::Array(vec![element]));
}

/// The little-endian bytes of `1.5f32` decode as `F32` to `Value::Float(1.5)`.
#[test]
fn from_element_bytes_float32_roundtrips() {
    let source = 1.5_f32;
    let decoded = Value::from_element_bytes(DataType::F32, &source.to_le_bytes()).unwrap();
    assert_eq!(decoded, Value::Float(1.5));
}

/// The little-endian bytes of `2.5f64` decode as `F64` to `Value::Float(2.5)`.
#[test]
fn from_element_bytes_float64_roundtrips() {
    let source = 2.5_f64;
    let decoded = Value::from_element_bytes(DataType::F64, &source.to_le_bytes()).unwrap();
    assert_eq!(decoded, Value::Float(2.5));
}

/// Tags 8, 9, 10 and 11 decode all-zero payloads of 2, 2, 4 and 8 bytes
/// respectively to `Value::Float(0.0)`.
///
/// NOTE(review): presumably these tags are F16, BF16, F32 and F64 in that
/// order — confirm against the tag table.
#[test]
fn from_element_type_tag_maps_new_float_tags() {
    // One zeroed buffer, sliced to the payload width each tag is given.
    let zeros = [0u8; 8];
    let zero = Value::Float(0.0);
    assert_eq!(Value::from_element_type_tag(8, &zeros[..2]).unwrap(), zero);
    assert_eq!(Value::from_element_type_tag(9, &zeros[..2]).unwrap(), zero);
    assert_eq!(Value::from_element_type_tag(10, &zeros[..4]).unwrap(), zero);
    assert_eq!(Value::from_element_type_tag(11, &zeros[..]).unwrap(), zero);
}

/// Decoding tag 13 with a payload that declares an element size of 4 and a count
/// of 1 yields a one-element array whose element is the four trailing bytes.
#[test]
fn from_element_type_tag_array_uses_wire_element_size() {
    let element: [u8; 4] = [0xAB, 0xCD, 0xEF, 0x00];

    // Two eight-byte fields followed by the element; `to_le_bytes` on a `u64`
    // reproduces the original `N, 0, 0, 0, 0, 0, 0, 0` byte groups.
    let mut wire: Vec<u8> = Vec::new();
    wire.extend_from_slice(&4u64.to_le_bytes()); // wire element_size = 4
    wire.extend_from_slice(&1u64.to_le_bytes()); // count = 1
    wire.extend_from_slice(&element); // element 0

    let decoded = Value::from_element_type_tag(13, &wire).unwrap();
    assert_eq!(decoded, Value::Array(vec![Value::Bytes(element.to_vec())]));
}

/// Decoding tag 12 with a payload of rank byte, two eight-byte dimensions,
/// element tag 10 and 24 zero bytes yields a `[2, 3]` `F32` tensor whose data is
/// those 24 zero bytes.
#[test]
fn from_element_type_tag_tensor_parses_header() {
    let mut wire = vec![2u8]; // rank
    wire.extend_from_slice(&2u64.to_le_bytes()); // dim0
    wire.extend_from_slice(&3u64.to_le_bytes()); // dim1
    wire.push(10); // element tag F32
    wire.extend_from_slice(&[0u8; 24]); // 2 * 3 * 4 = 24 payload bytes (f32 zero words)

    let want = TensorValue {
        shape: vec![2, 3],
        element_type: DataType::F32,
        data: vec![0; 24],
    };
    let decoded = Value::from_element_type_tag(12, &wire).unwrap();
    assert_eq!(decoded, Value::Tensor(want));
}

/// Decoding an `Array { element_size: 4 }` payload that carries an element size
/// of 4, a count of 2 and two four-byte elements yields both elements, in order,
/// as `Value::Bytes`.
#[test]
fn from_element_bytes_array_parses_length_prefix() {
    let elements: [[u8; 4]; 2] = [
        [0x01, 0x00, 0x00, 0x00], // element 0
        [0x02, 0x00, 0x00, 0x00], // element 1
    ];

    // Two eight-byte fields, then the elements back to back.
    let mut wire: Vec<u8> = Vec::new();
    wire.extend_from_slice(&4u64.to_le_bytes()); // element_size = 4
    wire.extend_from_slice(&2u64.to_le_bytes()); // count = 2
    for element in &elements {
        wire.extend_from_slice(element);
    }

    let dtype = DataType::Array { element_size: 4 };
    let decoded = Value::from_element_bytes(dtype, &wire).unwrap();
    let want: Vec<Value> = elements
        .iter()
        .map(|element| Value::Bytes(element.to_vec()))
        .collect();
    assert_eq!(decoded, Value::Array(want));
}

/// `Value::Float(NaN)` compares equal to another `Value::Float(NaN)` and unequal
/// to `Value::Float(0.0)`; plain `f64` equality would make NaN unequal to itself.
///
/// NOTE(review): the test name attributes this to bit-pattern comparison; the
/// test itself only pins the NaN behaviour shown here.
#[test]
fn float_equality_uses_bit_pattern() {
    let nan = Value::Float(f64::NAN);
    assert_eq!(nan, Value::Float(f64::NAN));
    assert_ne!(nan, Value::Float(0.0));
}