vyre-libs 0.6.2

vyre Category A library ecosystem - pure-IR compositions over vyre-ops hardware primitives
Documentation
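//! Tests for `linear_tiled`: compares its output against a scalar matrix-vector
//! product and against `linear_tiled_reference`, both evaluated in the vyre reference interpreter.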

use super::*;

#[test]
fn linear_tiled_matches_scalar_linear_with_bias_and_tail_tile() {
    // Neither 37 nor 65 is a multiple of the 16-wide tile, so the tiled program
    // has to cope with a partially filled trailing tile.
    let (in_dim, out_dim, tile) = (37_u32, 65_u32, 16_u32);
    let pack = vyre_primitives::wire::pack_u32_slice;

    // Deterministic, non-trivial operands: affine sequences of the element index.
    let x: Vec<u32> = (0..in_dim)
        .map(|i| i.wrapping_mul(3).wrapping_add(1))
        .collect();
    let w: Vec<u32> = (0..(in_dim * out_dim))
        .map(|i| i.wrapping_mul(5).wrapping_add(7))
        .collect();
    let b: Vec<u32> = (0..out_dim)
        .map(|i| i.wrapping_mul(11).wrapping_add(13))
        .collect();

    let program = linear_tiled("x", "w", "b", "out", in_dim, out_dim, tile)
        .expect("Fix: linear_tiled must accept positive dimensions.");
    let inputs = [
        Value::from(pack(&x)),
        Value::from(pack(&w)),
        Value::from(pack(&b)),
        Value::from(output_zero_bytes(&program)),
    ];
    let outputs = vyre_reference::reference_eval(&program, &inputs)
        .expect("Fix: linear_tiled must execute in the reference interpreter.");
    let actual = vyre_primitives::wire::decode_u32_le_bytes_all(&outputs[0].to_bytes());

    // Scalar oracle: out[j] = b[j] + sum_k x[k] * w[k * out_dim + j], all wrapping.
    let stride = out_dim as usize;
    let expected: Vec<u32> = b
        .iter()
        .enumerate()
        .map(|(j, &bias)| {
            x.iter().enumerate().fold(bias, |acc, (k, &xk)| {
                acc.wrapping_add(xk.wrapping_mul(w[k * stride + j]))
            })
        })
        .collect();

    assert_eq!(actual, expected);
}

#[test]
fn linear_tiled_accepts_logical_output_fixture_for_padded_storage() {
    use crate::test_support::byte_pack::u32_bytes;

    // 4x4 problem with a 2-wide tile.
    let program = linear_tiled("x", "w", "b", "out", 4, 4, 2)
        .expect("Fix: linear_tiled must build for a 4x4 fixture.");

    // The output buffer's backing storage may be padded, but only the 4 logical
    // u32 results (16 bytes) may be exposed to the caller.
    let output = program
        .buffers()
        .iter()
        .find(|buffer| buffer.is_output())
        .expect("Fix: linear_tiled must declare an output buffer.");
    let padded = output.count() > 4;
    assert!(
        padded,
        "Fix: linear_tiled should keep padded storage for cooperative tile launch geometry."
    );
    assert_eq!(
        output.output_byte_range(),
        Some(0..16),
        "Fix: linear_tiled must expose only the logical output bytes.",
    );

    // The caller passes exactly the 16 logical output bytes, not the padded size.
    let inputs = [
        Value::from(u32_bytes(&[0, 1, 2, 3])),
        Value::from(u32_bytes(&(0..16).collect::<Vec<_>>())),
        Value::from(u32_bytes(&[0; 4])),
        Value::from(vec![0u8; 16]),
    ];
    let outputs = vyre_reference::reference_eval(&program, &inputs)
        .expect("Fix: reference interpreter must pad output backing storage internally.");

    // With x = [0, 1, 2, 3], w[k * 4 + j] = 4k + j and zero bias:
    // out[j] = sum_k k * (4k + j) = 56 + 6j.
    let expected = u32_bytes(&[56, 62, 68, 74]);
    assert_eq!(outputs[0].to_bytes(), expected);
}

/// Differential test: `linear_tiled` must agree byte-for-byte with
/// `linear_tiled_reference` across a sweep of awkward `(in_dim, out_dim, tile)` shapes.
#[test]
fn linear_tiled_matches_reference_on_adversarial_shapes() {
    let shapes: &[(u32, u32, u32)] = &[
        // Degenerate single-element problem.
        (1, 1, 1),
        // Tile one wider than `out_dim`: a single, partially filled tile.
        (3, 5, 7),
        (15, 16, 17),
        (31, 32, 33),
        (63, 64, 65),
        (127, 128, 129),
        (255, 256, 257),
        // Tile narrower than both dimensions and dividing neither of them:
        // several full tiles followed by a ragged tail. The shapes above never
        // exercise more than one tile.
        (5, 100, 3),
        (17, 33, 8),
        (33, 31, 4),
        (40, 24, 16),
        // Both dimensions exact multiples of the tile: no tail, as a control.
        (64, 64, 16),
    ];
    let to_bytes = vyre_primitives::wire::pack_u32_slice;

    for &(in_dim, out_dim, tile) in shapes {
        let x: Vec<u32> = (0..in_dim)
            .map(|i| i.wrapping_mul(3).wrapping_add(1))
            .collect();
        let w: Vec<u32> = (0..(in_dim * out_dim))
            .map(|i| i.wrapping_mul(5).wrapping_add(7))
            .collect();
        let b: Vec<u32> = (0..out_dim)
            .map(|i| i.wrapping_mul(11).wrapping_add(13))
            .collect();

        let opt = linear_tiled("x", "w", "b", "out", in_dim, out_dim, tile)
            .expect("Fix: linear_tiled must accept positive dimensions.");
        let ref_ = linear_tiled_reference("x", "w", "b", "out", in_dim, out_dim, tile)
            .expect("Fix: linear_tiled_reference must accept positive dimensions.");

        let opt_outputs = vyre_reference::reference_eval(
            &opt,
            &[
                Value::from(to_bytes(&x)),
                Value::from(to_bytes(&w)),
                Value::from(to_bytes(&b)),
                // The tiled program may use padded output storage, so size it from the program.
                Value::from(output_zero_bytes(&opt)),
            ],
        )
        .expect("Fix: linear_tiled must execute.");
        let ref_outputs = vyre_reference::reference_eval(
            &ref_,
            &[
                Value::from(to_bytes(&x)),
                Value::from(to_bytes(&w)),
                Value::from(to_bytes(&b)),
                Value::from(vec![0u8; out_dim as usize * 4]),
            ],
        )
        .expect("Fix: linear_tiled_reference must execute.");

        assert_eq!(
            opt_outputs[0].to_bytes(),
            ref_outputs[0].to_bytes(),
            "linear_tiled must match linear_tiled_reference for (in_dim={in_dim}, out_dim={out_dim}, tile={tile})"
        );
    }
}

/// Differential test: `linear_tiled` must agree byte-for-byte with
/// `linear_tiled_reference` on boundary-valued inputs and weights, including
/// `u32::MAX` operands whose products wrap.
#[test]
fn linear_tiled_matches_reference_on_boundary_values() {
    let in_dim = 32_u32;
    let out_dim = 48_u32;
    let tile = 16_u32;
    let to_bytes = vyre_primitives::wire::pack_u32_slice;

    // The programs depend only on the shape, so build them once instead of per case.
    let opt = linear_tiled("x", "w", "b", "out", in_dim, out_dim, tile)
        .expect("Fix: linear_tiled must accept positive dimensions.");
    let ref_ = linear_tiled_reference("x", "w", "b", "out", in_dim, out_dim, tile)
        .expect("Fix: linear_tiled_reference must accept positive dimensions.");

    let boundary_cases: Vec<Vec<u32>> = vec![
        vec![0; in_dim as usize],
        vec![1; in_dim as usize],
        vec![u32::MAX; in_dim as usize],
        (0..in_dim).collect(),
    ];
    // Alternating 0/1 weights, plus all-`u32::MAX` weights so that multiplies
    // against `u32::MAX` inputs wrap.
    let weight_cases: [(&str, Vec<u32>); 2] = [
        ("alternating 0/1", (0..(in_dim * out_dim)).map(|i| i % 2).collect()),
        ("all u32::MAX", vec![u32::MAX; (in_dim * out_dim) as usize]),
    ];
    let b: Vec<u32> = (0..out_dim).collect();

    for (weights, w) in &weight_cases {
        for x in &boundary_cases {
            let opt_outputs = vyre_reference::reference_eval(
                &opt,
                &[
                    Value::from(to_bytes(x)),
                    Value::from(to_bytes(w)),
                    Value::from(to_bytes(&b)),
                    // The tiled program may use padded output storage, so size it from the program.
                    Value::from(output_zero_bytes(&opt)),
                ],
            )
            .expect("Fix: linear_tiled must execute.");
            let ref_outputs = vyre_reference::reference_eval(
                &ref_,
                &[
                    Value::from(to_bytes(x)),
                    Value::from(to_bytes(w)),
                    Value::from(to_bytes(&b)),
                    Value::from(vec![0u8; out_dim as usize * 4]),
                ],
            )
            .expect("Fix: linear_tiled_reference must execute.");

            assert_eq!(
                opt_outputs[0].to_bytes(),
                ref_outputs[0].to_bytes(),
                "linear_tiled must match linear_tiled_reference for {weights} weights and boundary input {x:?}"
            );
        }
    }
}