vyre-spec 0.6.2

Frozen data contracts for vyre - OpDef, AlgebraicLaw, Category, IntrinsicTable
Documentation
use crate::{DataType, QuantizationScale, QuantizationZeroPoint, TypeId};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::collections::BTreeMap;
use std::fmt::Debug;

/// Builds the table of representative `DataType` cases shared by the wire-tag checks.
///
/// Every row is `(variant name, sample value, tag byte)`. Rows are listed in
/// ascending tag order, running contiguously from `0x01` (`U32`) through
/// `0x1F` (`Quantized`). Variants that carry a payload use one fixed sample
/// payload each.
pub(super) fn data_type_wire_cases() -> Vec<(&'static str, DataType, u8)> {
    // Payload-carrying samples are named up front so the table below stays
    // at one row per tag.
    let array = DataType::Array { element_size: 4 };
    let handle = DataType::Handle(TypeId(9));
    let vector = DataType::Vec {
        element: Box::new(DataType::U16),
        count: 8,
    };
    let tensor_shaped = DataType::TensorShaped {
        element: Box::new(DataType::F32),
        shape: smallvec::smallvec![4, 16, 32],
    };
    let sparse_csr = DataType::SparseCsr {
        element: Box::new(DataType::F32),
    };
    let sparse_coo = DataType::SparseCoo {
        element: Box::new(DataType::I32),
    };
    let sparse_bsr = DataType::SparseBsr {
        element: Box::new(DataType::I4),
        block_rows: 8,
        block_cols: 8,
    };
    let device_mesh = DataType::DeviceMesh {
        axes: smallvec::smallvec![2, 4],
    };
    let quantized = DataType::Quantized {
        storage: Box::new(DataType::I8),
        scale: QuantizationScale::PerChannel { axis: 0 },
        zero_point: QuantizationZeroPoint::PerGroup { group_size: 32 },
    };

    Vec::from([
        ("U32", DataType::U32, 0x01),
        ("I32", DataType::I32, 0x02),
        ("U64", DataType::U64, 0x03),
        ("Vec2U32", DataType::Vec2U32, 0x04),
        ("Vec4U32", DataType::Vec4U32, 0x05),
        ("Bool", DataType::Bool, 0x06),
        ("Bytes", DataType::Bytes, 0x07),
        ("Array", array, 0x08),
        ("F16", DataType::F16, 0x09),
        ("BF16", DataType::BF16, 0x0A),
        ("F32", DataType::F32, 0x0B),
        ("F64", DataType::F64, 0x0C),
        ("Tensor", DataType::Tensor, 0x0D),
        ("U8", DataType::U8, 0x0E),
        ("U16", DataType::U16, 0x0F),
        ("I8", DataType::I8, 0x10),
        ("I16", DataType::I16, 0x11),
        ("I64", DataType::I64, 0x12),
        ("Handle", handle, 0x13),
        ("Vec", vector, 0x14),
        ("TensorShaped", tensor_shaped, 0x15),
        ("SparseCsr", sparse_csr, 0x16),
        ("SparseCoo", sparse_coo, 0x17),
        ("SparseBsr", sparse_bsr, 0x18),
        ("F8E4M3", DataType::F8E4M3, 0x19),
        ("F8E5M2", DataType::F8E5M2, 0x1A),
        ("I4", DataType::I4, 0x1B),
        ("FP4", DataType::FP4, 0x1C),
        ("NF4", DataType::NF4, 0x1D),
        ("DeviceMesh", device_mesh, 0x1E),
        ("Quantized", quantized, 0x1F),
    ])
}

/// Asserts that every value in `values` reports the same builtin wire tag.
///
/// `name` is the variant name interpolated into the failure message and
/// `expected` is the tag each value's `builtin_wire_tag()` must return
/// (wrapped in `Some`).
///
/// # Panics
///
/// Panics on the first value whose `builtin_wire_tag()` differs from
/// `Some(expected)`.
pub(super) fn assert_payload_independent_tag(name: &str, values: &[DataType], expected: u8) {
    let wanted = Some(expected);
    values.iter().for_each(|candidate| {
        assert_eq!(
            candidate.builtin_wire_tag(),
            wanted,
            "Fix: DataType::{name} tag must not depend on payload fields"
        );
    });
}

/// Asserts that every `(name, tag)` entry in `tags` uses a distinct tag inside
/// the `0x01..=0x7F` range.
///
/// `family` is the type name used as a prefix in failure messages. Entries are
/// visited in slice order; for each one the range check runs first, then the
/// map-based duplicate check, then an explicit comparison against every entry
/// of the slice.
///
/// # Panics
///
/// Panics if a tag lies outside `0x01..=0x7F` or if two entries share a tag.
pub(super) fn assert_pairwise_unique_tags(family: &str, tags: &[(&'static str, u8)]) {
    let mut seen = BTreeMap::new();
    for (left_index, &(left_name, left_tag)) in tags.iter().enumerate() {
        let in_builtin_range = matches!(left_tag, 0x01..=0x7F);
        assert!(
            in_builtin_range,
            "Fix: {family}::{left_name} uses non-builtin tag {left_tag:#04x}"
        );

        let previous_owner = seen.insert(left_tag, left_name);
        assert_eq!(
            previous_owner, None,
            "Fix: {family}::{left_name} duplicates wire tag {left_tag:#04x}"
        );

        // Full pairwise sweep: each entry is compared against every entry,
        // including itself (which must trivially match).
        for (right_index, &(right_name, right_tag)) in tags.iter().enumerate() {
            if left_index != right_index {
                assert_ne!(
                    left_tag, right_tag,
                    "Fix: {family}::{left_name} and {family}::{right_name} share wire tag {left_tag:#04x}"
                );
                continue;
            }
            assert_eq!(left_tag, right_tag);
        }
    }
}

/// Asserts that `value` survives a JSON serialize/deserialize cycle unchanged.
///
/// The value is encoded with `serde_json::to_string`, decoded back into `T`
/// with `serde_json::from_str`, and the result is compared to the input with
/// `PartialEq`.
///
/// # Panics
///
/// Panics if serialization fails, if deserialization fails, or if the decoded
/// value is not equal to `value`.
pub(super) fn assert_json_roundtrip<T>(value: T)
where
    T: Serialize + DeserializeOwned + PartialEq + Debug,
{
    let json_text = serde_json::to_string(&value)
        .expect("Fix: representative frozen contract value must serialize to JSON");
    let round_tripped: T = serde_json::from_str(&json_text)
        .expect("Fix: representative frozen contract value must deserialize from JSON");
    assert_eq!(
        round_tripped, value,
        "Fix: JSON round-trip drifted for representative contract value"
    );
}