vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
use crate::spec::types::DataType;
use crate::spec::value::error::ValueError;
use crate::spec::value::tensor::TensorValue;
use crate::spec::value::value_enum::Value;

/// Decodes a [`TensorValue`] from its wire representation.
///
/// Wire format (multi-byte integers are little-endian):
///
/// ```text
/// [rank: u8][shape: rank x u64][element_tag: u8][payload: element_count * element_size bytes]
/// ```
///
/// A rank of 0 produces an empty shape whose element count is 1 (the empty product).
/// The payload is copied into the returned value.
///
/// # Errors
///
/// - [`ValueError::TruncatedInput`] when the input is shorter than the header it declares,
///   when the element count or payload size overflows `usize` (reported with
///   `expected: usize::MAX`), or when the payload length differs from
///   `element_count * element_size`. The last case also fires when the payload is too
///   *long*. Dimensions are multiplied left to right, so a shape whose running product
///   overflows is rejected even if a later dimension is 0.
/// - [`ValueError::UnknownElementTypeTag`] when the element tag is not one of the tags
///   handled in the `match` below.
/// - [`ValueError::UnsupportedElementType`] when the decoded type reports a
///   `min_bytes()` of 0.
fn parse_tensor(bytes: &[u8]) -> Result<TensorValue, ValueError> {
    // The smallest valid header (rank 0) is the rank byte plus the element tag byte.
    if bytes.len() < 2 {
        return Err(ValueError::TruncatedInput {
            expected: 2,
            actual: bytes.len(),
        });
    }
    let rank = usize::from(bytes[0]);
    // rank <= 255, so the header is at most 2 + 255 * 8 = 2042 bytes: no overflow possible.
    let header_len = 2 + rank * 8;
    if bytes.len() < header_len {
        return Err(ValueError::TruncatedInput {
            expected: header_len,
            actual: bytes.len(),
        });
    }
    let (header, payload) = bytes.split_at(header_len);

    // Dimensions sit between the rank byte and the trailing element tag byte.
    let shape: Vec<u64> = header[1..header_len - 1]
        .chunks_exact(8)
        .map(|chunk| {
            let mut dim = [0u8; 8];
            dim.copy_from_slice(chunk);
            u64::from_le_bytes(dim)
        })
        .collect();

    let element_type = match header[header_len - 1] {
        0 => DataType::U32,
        1 => DataType::I32,
        2 => DataType::U64,
        5 => DataType::Bool,
        8 => DataType::F16,
        9 => DataType::BF16,
        10 => DataType::F32,
        11 => DataType::F64,
        _ => return Err(ValueError::UnknownElementTypeTag),
    };
    let element_size = element_type.min_bytes();
    if element_size == 0 {
        return Err(ValueError::UnsupportedElementType);
    }

    // Size computations that overflow `usize` are reported as an unsatisfiable length.
    let overflow = || ValueError::TruncatedInput {
        expected: usize::MAX,
        actual: payload.len(),
    };
    let element_count = shape.iter().try_fold(1usize, |acc, &dim| {
        let dim = usize::try_from(dim).map_err(|_| overflow())?;
        acc.checked_mul(dim).ok_or_else(overflow)
    })?;
    let expected_payload = element_count
        .checked_mul(element_size)
        .ok_or_else(overflow)?;

    if payload.len() != expected_payload {
        return Err(ValueError::TruncatedInput {
            // Saturate: a shape-derived payload size near usize::MAX plus the header
            // length would otherwise overflow (a panic in debug builds) on hostile input.
            expected: header_len.saturating_add(expected_payload),
            actual: bytes.len(),
        });
    }

    Ok(TensorValue {
        shape,
        element_type,
        data: payload.to_vec(),
    })
}