vyre-foundation 0.6.1

Foundation layer: IR, type system, memory model, wire format. Zero application semantics. Part of the vyre GPU compiler.
Documentation
use super::*;
use crate::ir::{BufferDecl, DataType, Expr, Node, Program};

/// Encoding repeatedly into one shared buffer with `to_wire_into` must yield
/// the same bytes as concatenating the results of individual `to_wire` calls.
#[test]
fn to_wire_into_appends_byte_for_byte() {
    // Number of encodings performed on each side of the comparison.
    const ROUNDS: usize = 100;

    let buffers = vec![
        BufferDecl::read_write("a", 0, DataType::U32),
        BufferDecl::read("b", 1, DataType::U32),
    ];
    let body = vec![
        Node::let_bind("idx", Expr::gid_x()),
        Node::store("a", Expr::var("idx"), Expr::load("b", Expr::var("idx"))),
    ];
    let program = Program::wrapped(buffers, [64, 1, 1], body);

    // Reference bytes: a fresh encoding per round, copied onto the end.
    let mut separate = Vec::new();
    for _ in 0..ROUNDS {
        let encoded = to_wire(&program).unwrap();
        separate.extend_from_slice(&encoded);
    }

    // Candidate bytes: every round writes into the same caller-owned buffer,
    // which is never cleared in between.
    let mut reused = Vec::new();
    for _ in 0..ROUNDS {
        to_wire_into(&program, &mut reused).unwrap();
    }

    assert_eq!(
        separate, reused,
        "100 separate to_wire calls must match 100 to_wire_into calls into the same buffer"
    );
}

/// The section encoders that accept caller-provided scratch vectors must leave
/// those vectors' data pointer and capacity unchanged when called a second
/// time with the same inputs.
#[test]
fn encode_section_helpers_reuse_caller_scratch() {
    let buffers = vec![
        BufferDecl::read_write("a", 0, DataType::U32).with_count(64),
        BufferDecl::read("b", 1, DataType::U32).with_count(64),
        BufferDecl::read("mask", 2, DataType::Bool).with_count(64),
    ];
    let body = vec![
        Node::let_bind("idx", Expr::gid_x()),
        Node::store("a", Expr::var("idx"), Expr::load("b", Expr::var("idx"))),
    ];
    let program = Program::wrapped(buffers, [64, 1, 1], body);

    let mut out = Vec::with_capacity(2048);

    // --- Node section: one scratch vector (`payload`). ---
    let mut payload = Vec::with_capacity(2048);
    put_nodes_section_with_payload(&mut out, &program, program.buffers(), &mut payload)
        .expect("Fix: node section must encode");
    // Snapshot the scratch allocation after the first encode.
    let (payload_ptr, payload_capacity) = (payload.as_ptr(), payload.capacity());
    out.clear();
    put_nodes_section_with_payload(&mut out, &program, program.buffers(), &mut payload)
        .expect("Fix: node section must encode a second time");
    assert_eq!(payload.as_ptr(), payload_ptr);
    assert_eq!(payload.capacity(), payload_capacity);

    // --- Memory regions: two scratch vectors (`shape` and `hints`). ---
    // `out` still holds the second node-section encoding at this point; only
    // the scratch vectors are inspected below.
    let mut shape = Vec::with_capacity(64);
    let mut hints = Vec::with_capacity(64);
    put_memory_regions_with_scratch(&mut out, program.buffers(), &mut shape, &mut hints)
        .expect("Fix: memory regions must encode");
    let (shape_ptr, hints_ptr) = (shape.as_ptr(), hints.as_ptr());
    let (shape_capacity, hints_capacity) = (shape.capacity(), hints.capacity());
    out.clear();
    put_memory_regions_with_scratch(&mut out, program.buffers(), &mut shape, &mut hints)
        .expect("Fix: memory regions must encode a second time");
    assert_eq!(shape.as_ptr(), shape_ptr);
    assert_eq!(hints.as_ptr(), hints_ptr);
    assert_eq!(shape.capacity(), shape_capacity);
    assert_eq!(hints.capacity(), hints_capacity);
}

/// Checks that the encoded program ends with the expected output-set bytes,
/// that decoding reports buffer indices 1 and 2 as outputs, and that a
/// tampered trailing byte is rejected with an error mentioning "output-set".
#[test]
fn output_set_is_serialized_and_validated() {
    // Byte range overwritten with the recomputed BLAKE3 hash below.
    // NOTE(review): presumably the header's integrity digest — confirm against
    // the wire-format definition.
    const DIGEST_START: usize = 8;
    const DIGEST_END: usize = 40;

    let buffers = vec![
        BufferDecl::read("input", 0, DataType::U32).with_count(4),
        BufferDecl::output("out", 1, DataType::U32).with_count(4),
        BufferDecl::read_write("scratch_out", 2, DataType::U32).with_count(4),
    ];
    let body = vec![Node::store(
        "out",
        Expr::u32(0),
        Expr::load("input", Expr::u32(0)),
    )];
    let program = Program::wrapped(buffers, [64, 1, 1], body);

    let encoded = to_wire(&program).expect("Fix: output-set program must encode");

    // The final three bytes are expected to be 2, 1, 2.
    // NOTE(review): looks like a count followed by the two indices — confirm.
    let tail_start = encoded.len() - 3;
    assert_eq!(
        &encoded[tail_start..],
        &[2, 1, 2],
        "OutputSet must list the two writable buffer indices in declaration order"
    );

    let decoded =
        Program::from_wire(&encoded).expect("Fix: encoded output-set program must decode");
    assert_eq!(decoded.output_buffer_indices(), &[1, 2]);

    // Zero the very last byte, then rehash everything after the digest range
    // and write the new hash back in, so the stored hash matches the tampered
    // bytes.
    let mut tampered = encoded;
    let final_index = tampered.len() - 1;
    tampered[final_index] = 0;
    let rehashed = blake3::hash(&tampered[DIGEST_END..]);
    tampered[DIGEST_START..DIGEST_END].copy_from_slice(rehashed.as_bytes());

    let rejection =
        Program::from_wire(&tampered).expect_err("tampered output-set must be rejected");
    let err = rejection.to_string();
    assert!(
        err.contains("output-set"),
        "decode error must name the corrupt OutputSet: {err}"
    );
}