vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
use crate::spec::types::DataType;

use super::error::ValueError;
use super::tensor::TensorValue;
use super::value_enum::Value;

impl Value {
    /// Build a scalar value from a conformance element type and little-endian bytes.
    ///
    /// Decodes a raw byte slice into a typed value so that the reference interpreter can process
    /// external data payloads.
    ///
    /// Integers and floats are read as little-endian. `Bool` is read as a little-endian `u32` and
    /// is `true` for any non-zero word. `Bytes`, `Vec2U32` and `Vec4U32` are kept as raw bytes.
    /// `F16`, `BF16` and `F32` are widened to `f64` without changing their numeric value.
    /// `Tensor` and `Array` payloads are handed to this module's wire-format parsers.
    ///
    /// # Errors
    ///
    /// Returns `Err(ValueError::TruncatedInput)` when `ty.min_bytes()` is non-zero and differs
    /// from `bytes.len()`, or when a fixed-width scalar is not given exactly its width. Errors
    /// produced while parsing a tensor or array payload are propagated unchanged.
    ///
    /// # Examples
    ///
    /// Decode a valid 32-bit unsigned integer from a 4-byte slice:
    ///
    /// ```rust
    /// # use vyre_conform::spec::value::Value;
    /// # use vyre_conform::types::DataType;
    /// let bytes = [0x01, 0x00, 0x00, 0x00];
    /// let val = Value::from_element_bytes(DataType::U32, &bytes).expect("valid bytes");
    /// assert_eq!(val, Value::U32(1));
    /// ```
    #[inline]
    pub fn from_element_bytes(ty: DataType, bytes: &[u8]) -> Result<Self, ValueError> {
        // Types with a fixed footprint must be given exactly that many bytes; a `min_bytes()` of
        // zero means the length is validated by the type-specific decoding below instead.
        if ty.min_bytes() > 0 && bytes.len() != ty.min_bytes() {
            return Err(ValueError::TruncatedInput {
                expected: ty.min_bytes(),
                actual: bytes.len(),
            });
        }
        Ok(match ty {
            DataType::U32 => Self::U32(u32::from_le_bytes(read_array4(bytes)?)),
            DataType::I32 => Self::I32(i32::from_le_bytes(read_array4(bytes)?)),
            DataType::U64 => Self::U64(u64::from_le_bytes(read_array8(bytes)?)),
            DataType::Bool => Self::Bool(u32::from_le_bytes(read_array4(bytes)?) != 0),
            DataType::Bytes | DataType::Vec2U32 | DataType::Vec4U32 => Self::Bytes(bytes.to_vec()),
            DataType::F16 => {
                let bits = u16::from_le_bytes(read_array2(bytes)?);
                // Decode the binary16 fields properly. Zero-extending the bit pattern into an
                // f32 would turn every input (including -0.0, infinities and NaN) into a tiny
                // positive-or-denormal number.
                Self::Float(f64::from(f16_bits_to_f32(bits)))
            }
            DataType::BF16 => {
                let bits = u16::from_le_bytes(read_array2(bytes)?);
                // BF16 is promoted to f32 by placing its 16 bits in the upper half of the word.
                Self::Float(f64::from(f32::from_bits(u32::from(bits) << 16)))
            }
            DataType::F32 => Self::Float(f64::from(f32::from_le_bytes(read_array4(bytes)?))),
            DataType::F64 => Self::Float(f64::from_le_bytes(read_array8(bytes)?)),
            DataType::Tensor => Self::Tensor(parse_tensor(bytes)?),
            DataType::Array { element_size } => Self::Array(parse_array(bytes, element_size)?),
        })
    }

    /// Build a scalar value from a stable wire-format element type tag.
    ///
    /// Maps the integer `tag` to a [`DataType`] and then decodes `bytes` exactly as
    /// [`Value::from_element_bytes`] does, so that binary wire-format messages can be ingested
    /// directly into the test environment.
    ///
    /// Tag `13` selects `DataType::Array { element_size: 0 }`; the declared element size of zero
    /// tells the array parser to take the element width from the wire header.
    ///
    /// # Errors
    ///
    /// Returns `Err(ValueError::UnknownElementTypeTag)` when `tag` does not map to a supported
    /// [`DataType`] (tag `7` and every tag above `13` are unassigned). Any error from
    /// [`Value::from_element_bytes`], such as `ValueError::TruncatedInput` for a wrong byte
    /// length, is returned unchanged.
    ///
    /// # Examples
    ///
    /// Construct a signed 32-bit integer value from the standard type tag `1`:
    ///
    /// ```rust
    /// # use vyre_conform::spec::value::Value;
    /// let bytes = [0xff, 0xff, 0xff, 0xff];
    /// let val = Value::from_element_type_tag(1, &bytes).expect("valid bytes");
    /// assert_eq!(val, Value::I32(-1));
    /// ```
    #[inline]
    pub fn from_element_type_tag(tag: u8, bytes: &[u8]) -> Result<Self, ValueError> {
        let ty = match tag {
            0 => DataType::U32,
            1 => DataType::I32,
            2 => DataType::U64,
            3 => DataType::Vec2U32,
            4 => DataType::Vec4U32,
            5 => DataType::Bool,
            6 => DataType::Bytes,
            8 => DataType::F16,
            9 => DataType::BF16,
            10 => DataType::F32,
            11 => DataType::F64,
            12 => DataType::Tensor,
            13 => DataType::Array { element_size: 0 },
            _ => return Err(ValueError::UnknownElementTypeTag),
        };
        Self::from_element_bytes(ty, bytes)
    }
}

/// Widen an IEEE 754 binary16 bit pattern to the `f32` holding the same value.
///
/// Layout of `bits`: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits. Signed zeros,
/// subnormals, infinities and NaNs are all preserved.
fn f16_bits_to_f32(bits: u16) -> f32 {
    // One unit in the last place of a binary16 subnormal: 2^-24.
    const SUBNORMAL_UNIT: f32 = 1.0 / 16_777_216.0;

    let sign = u32::from(bits & 0x8000) << 16;
    let exponent = u32::from((bits >> 10) & 0x1f);
    let mantissa = bits & 0x03ff;
    let magnitude = match exponent {
        // Zero and subnormals encode `mantissa * 2^-24`; that product is exact in f32.
        0 => (f32::from(mantissa) * SUBNORMAL_UNIT).to_bits(),
        // All-ones exponent: infinity (mantissa == 0) or NaN (payload kept in the top bits).
        0x1f => 0x7f80_0000 | (u32::from(mantissa) << 13),
        // Normal numbers: re-bias the exponent (127 - 15 = 112) and left-align the mantissa.
        _ => ((exponent + 112) << 23) | (u32::from(mantissa) << 13),
    };
    f32::from_bits(sign | magnitude)
}

fn read_array2(bytes: &[u8]) -> Result<[u8; 2], ValueError> {
    if bytes.len() != 2 {
        return Err(ValueError::TruncatedInput {
            expected: 2,
            actual: bytes.len(),
        });
    }
    let mut arr = [0u8; 2];
    arr.copy_from_slice(bytes);
    Ok(arr)
}

fn read_array4(bytes: &[u8]) -> Result<[u8; 4], ValueError> {
    if bytes.len() != 4 {
        return Err(ValueError::TruncatedInput {
            expected: 4,
            actual: bytes.len(),
        });
    }
    let mut arr = [0u8; 4];
    arr.copy_from_slice(bytes);
    Ok(arr)
}

fn read_array8(bytes: &[u8]) -> Result<[u8; 8], ValueError> {
    if bytes.len() != 8 {
        return Err(ValueError::TruncatedInput {
            expected: 8,
            actual: bytes.len(),
        });
    }
    let mut arr = [0u8; 8];
    arr.copy_from_slice(bytes);
    Ok(arr)
}

/// Parse a tensor from its wire encoding.
///
/// Wire format: `[rank: u8, shape: rank * u64 LE, element_tag: u8, payload...]`. The payload must
/// be exactly `product(shape) * element_size` bytes, where `element_size` is the `min_bytes()` of
/// the element type selected by `element_tag`. The payload bytes are stored undecoded.
///
/// # Errors
///
/// * `ValueError::TruncatedInput` when the input is shorter than the header, when the payload
///   length differs from the length implied by the shape, or (with `expected: usize::MAX`) when
///   the element count or payload size does not fit in `usize`.
/// * `ValueError::UnknownElementTypeTag` when `element_tag` is not one of the scalar tags
///   accepted for tensor elements.
/// * `ValueError::UnsupportedElementType` when the element type reports a size of zero.
fn parse_tensor(bytes: &[u8]) -> Result<TensorValue, ValueError> {
    /// Error used when a size computed from the header cannot be represented in `usize`.
    fn size_overflow(actual: usize) -> ValueError {
        ValueError::TruncatedInput {
            expected: usize::MAX,
            actual,
        }
    }

    // Smallest valid message: the rank byte plus the element tag of a rank-0 tensor.
    if bytes.len() < 2 {
        return Err(ValueError::TruncatedInput {
            expected: 2,
            actual: bytes.len(),
        });
    }
    let rank = usize::from(bytes[0]);
    // `rank` comes from a single byte (at most 255), so this arithmetic cannot overflow.
    let tag_offset = 1 + rank * 8;
    let header_len = tag_offset + 1;
    if bytes.len() < header_len {
        return Err(ValueError::TruncatedInput {
            expected: header_len,
            actual: bytes.len(),
        });
    }

    let shape: Vec<u64> = bytes[1..tag_offset]
        .chunks_exact(8)
        .map(|dim| {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(dim);
            u64::from_le_bytes(raw)
        })
        .collect();

    let element_type = match bytes[tag_offset] {
        0 => DataType::U32,
        1 => DataType::I32,
        2 => DataType::U64,
        5 => DataType::Bool,
        8 => DataType::F16,
        9 => DataType::BF16,
        10 => DataType::F32,
        11 => DataType::F64,
        _ => return Err(ValueError::UnknownElementTypeTag),
    };
    let payload = &bytes[header_len..];
    let element_size = element_type.min_bytes();
    if element_size == 0 {
        return Err(ValueError::UnsupportedElementType);
    }

    let element_count = shape.iter().try_fold(1usize, |acc, &dim| {
        let dim = usize::try_from(dim).map_err(|_| size_overflow(payload.len()))?;
        acc.checked_mul(dim)
            .ok_or_else(|| size_overflow(payload.len()))
    })?;
    let expected_payload = element_count
        .checked_mul(element_size)
        .ok_or_else(|| size_overflow(payload.len()))?;
    if payload.len() != expected_payload {
        return Err(ValueError::TruncatedInput {
            // Saturate: a shape describing a payload just under `usize::MAX` must produce an
            // error, not an arithmetic overflow while building that error.
            expected: header_len.saturating_add(expected_payload),
            actual: bytes.len(),
        });
    }
    Ok(TensorValue {
        shape,
        element_type,
        data: payload.to_vec(),
    })
}

fn parse_array(bytes: &[u8], element_size: usize) -> Result<Vec<Value>, ValueError> {
    // Wire format: [element_size: u64 LE, count: u64 LE, payload...]
    if bytes.len() < 16 {
        return Err(ValueError::TruncatedInput {
            expected: 16,
            actual: bytes.len(),
        });
    }
    let wire_element_size = u64::from_le_bytes([
        bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7],
    ]) as usize;
    // When the declared element_size is zero (e.g. from a bare tag-13 dispatch),
    // trust the wire-encoded size.
    let effective_size = if element_size == 0 {
        wire_element_size
    } else {
        if wire_element_size != element_size {
            return Err(ValueError::MismatchedElementSize {
                declared: element_size,
                wire: wire_element_size,
            });
        }
        element_size
    };
    let count = u64::from_le_bytes([
        bytes[8], bytes[9], bytes[10], bytes[11], bytes[12], bytes[13], bytes[14], bytes[15],
    ]) as usize;
    let payload = &bytes[16..];
    let expected_payload = count.checked_mul(effective_size).unwrap_or(0);
    if payload.len() != expected_payload {
        return Err(ValueError::TruncatedInput {
            expected: 16 + expected_payload,
            actual: bytes.len(),
        });
    }
    let mut elements = Vec::with_capacity(count);
    for i in 0..count {
        let start = i * effective_size;
        let end = start + effective_size;
        elements.push(Value::Bytes(payload[start..end].to_vec()));
    }
    Ok(elements)
}