// vyre-conform 0.1.0
//
// Conformance suite for vyre backends — proves byte-identical output to the CPU reference.
//! Reference-interpreter self-tests for every op in `vyre::ops::primitive::float`.

use vyre::ops::primitive::float::*;
use vyre_conform::{reference::interp, spec::value::Value};

/// Serializes `values` into one flat little-endian byte buffer (4 bytes per `f32`).
fn f32_bytes(values: &[f32]) -> Vec<u8> {
    let mut buffer = Vec::with_capacity(values.len() * 4);
    for value in values {
        buffer.extend_from_slice(&value.to_le_bytes());
    }
    buffer
}

/// Serializes `values` into one flat little-endian byte buffer (4 bytes per `u32`).
fn u32_bytes(values: &[u32]) -> Vec<u8> {
    values
        .iter()
        .fold(Vec::with_capacity(values.len() * 4), |mut acc, word| {
            acc.extend_from_slice(&word.to_le_bytes());
            acc
        })
}

/// Serializes `values` into one flat little-endian byte buffer
/// (4 bytes per `i32`, two's-complement).
fn i32_bytes(values: &[i32]) -> Vec<u8> {
    let mut out = Vec::new();
    values
        .iter()
        .for_each(|n| out.extend_from_slice(&n.to_le_bytes()));
    out
}

/// Decodes a little-endian byte buffer into `f32`s.
///
/// Trailing bytes that do not fill a whole 4-byte word are ignored.
fn read_f32_bytes(bytes: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(bytes.len() / 4);
    for chunk in bytes.chunks_exact(4) {
        let mut word = [0u8; 4];
        word.copy_from_slice(chunk);
        floats.push(f32::from_le_bytes(word));
    }
    floats
}

/// Decodes a little-endian byte buffer into `u32`s.
///
/// Only whole 4-byte words are decoded; any leftover tail bytes are ignored.
fn read_u32_bytes(bytes: &[u8]) -> Vec<u32> {
    let whole_words = bytes.len() / 4;
    (0..whole_words)
        .map(|i| {
            let at = i * 4;
            u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]])
        })
        .collect()
}

/// Decodes a little-endian byte buffer into two's-complement `i32`s.
///
/// Only whole 4-byte words are decoded; any leftover tail bytes are ignored.
fn read_i32_bytes(bytes: &[u8]) -> Vec<i32> {
    let mut signed = Vec::new();
    for w in bytes.chunks_exact(4) {
        signed.push(i32::from_le_bytes([w[0], w[1], w[2], w[3]]));
    }
    signed
}

/// Runs a one-input `f32 -> f32` program through the reference interpreter.
///
/// Buffers passed in order: the encoded `input`, then a zero-filled output
/// buffer with one `f32` slot per input element. The first value the
/// interpreter returns is decoded as `f32`s (presumably the output buffer —
/// confirm against `interp::run`).
///
/// # Panics
/// Panics if the interpreter returns an error.
fn run_unary_f32(program: &vyre::ir::Program, input: &[f32]) -> Vec<f32> {
    let zeroed_output = f32_bytes(&vec![0.0f32; input.len()]);
    let buffers = [Value::Bytes(f32_bytes(input)), Value::Bytes(zeroed_output)];
    let outputs = interp::run(program, &buffers).unwrap();
    read_f32_bytes(&outputs[0].wide_bytes())
}

/// Runs a two-input `f32 x f32 -> f32` program through the reference interpreter.
///
/// Buffers passed in order: encoded `a`, encoded `b`, then a zero-filled
/// `f32` output buffer the same length as `a`. The first value the
/// interpreter returns is decoded as `f32`s.
///
/// # Panics
/// Panics if `a` and `b` differ in length, or if the interpreter returns an error.
fn run_binary_f32(program: &vyre::ir::Program, a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len());
    let lhs = Value::Bytes(f32_bytes(a));
    let rhs = Value::Bytes(f32_bytes(b));
    let output = Value::Bytes(f32_bytes(&vec![0.0f32; a.len()]));
    let results = interp::run(program, &[lhs, rhs, output]).unwrap();
    read_f32_bytes(&results[0].wide_bytes())
}

/// Runs a two-input `f32 x f32 -> u32` program (e.g. a comparison) through the
/// reference interpreter.
///
/// Buffers passed in order: encoded `a`, encoded `b`, then a zero-filled
/// `u32` output buffer the same length as `a`. The first value the
/// interpreter returns is decoded as `u32`s.
///
/// # Panics
/// Panics if `a` and `b` differ in length, or if the interpreter returns an error.
fn run_binary_f32_to_u32(program: &vyre::ir::Program, a: &[f32], b: &[f32]) -> Vec<u32> {
    assert_eq!(a.len(), b.len());
    let zeros = u32_bytes(&vec![0u32; a.len()]);
    let inputs = [
        Value::Bytes(f32_bytes(a)),
        Value::Bytes(f32_bytes(b)),
        Value::Bytes(zeros),
    ];
    let returned = interp::run(program, &inputs).unwrap();
    read_u32_bytes(&returned[0].wide_bytes())
}

/// Runs a two-input `f32 x f32 -> i32` program (e.g. a three-way compare)
/// through the reference interpreter.
///
/// Buffers passed in order: encoded `a`, encoded `b`, then a zero-filled
/// `i32` output buffer the same length as `a`. The first value the
/// interpreter returns is decoded as `i32`s.
///
/// # Panics
/// Panics if `a` and `b` differ in length, or if the interpreter returns an error.
fn run_binary_f32_to_i32(program: &vyre::ir::Program, a: &[f32], b: &[f32]) -> Vec<i32> {
    assert_eq!(a.len(), b.len());
    let placeholder = Value::Bytes(i32_bytes(&vec![0i32; a.len()]));
    let results = interp::run(
        program,
        &[Value::Bytes(f32_bytes(a)), Value::Bytes(f32_bytes(b)), placeholder],
    )
    .unwrap();
    read_i32_bytes(&results[0].wide_bytes())
}

/// Runs a one-input `f32 -> u32` program (e.g. a classification predicate)
/// through the reference interpreter.
///
/// Buffers passed in order: the encoded `input`, then a zero-filled `u32`
/// output buffer with one slot per input element. The first value the
/// interpreter returns is decoded as `u32`s.
///
/// # Panics
/// Panics if the interpreter returns an error.
fn run_unary_f32_to_u32(program: &vyre::ir::Program, input: &[f32]) -> Vec<u32> {
    let encoded_input = Value::Bytes(f32_bytes(input));
    let encoded_output = Value::Bytes(u32_bytes(&vec![0u32; input.len()]));
    let outputs = interp::run(program, &[encoded_input, encoded_output]).unwrap();
    read_u32_bytes(&outputs[0].wide_bytes())
}

/// Element-wise addition: `[1, 2] + [3, 4] == [4, 6]`.
#[test]
fn f32_add_reference() {
    let sums = run_binary_f32(&f32_add::F32Add::program(), &[1.0, 2.0], &[3.0, 4.0]);
    let expected: [f32; 2] = [4.0, 6.0];
    assert_eq!(sums, expected);
}

/// Element-wise subtraction: `[5, 3] - [2, 1] == [3, 2]`.
#[test]
fn f32_sub_reference() {
    let program = f32_sub::F32Sub::program();
    let minuends = [5.0f32, 3.0];
    let subtrahends = [2.0f32, 1.0];
    let expected: [f32; 2] = [3.0, 2.0];
    assert_eq!(run_binary_f32(&program, &minuends, &subtrahends), expected);
}

/// Element-wise multiplication: `[2, 3] * [3, 4] == [6, 12]`.
#[test]
fn f32_mul_reference() {
    let products = run_binary_f32(&f32_mul::F32Mul::program(), &[2.0, 3.0], &[3.0, 4.0]);
    assert_eq!(products, [6.0f32, 12.0]);
}

/// Element-wise division: `[6, 8] / [2, 4] == [3, 2]`.
#[test]
fn f32_div_reference() {
    let program = f32_div::F32Div::program();
    let (dividends, divisors) = ([6.0f32, 8.0], [2.0f32, 4.0]);
    let quotients = run_binary_f32(&program, &dividends, &divisors);
    assert_eq!(quotients, [3.0f32, 2.0]);
}

/// Absolute value strips the sign of negative inputs and leaves positives alone.
#[test]
fn f32_abs_reference() {
    let program = f32_abs::F32Abs::program();
    let input = [-1.0f32, 2.0, -3.0];
    let magnitudes = run_unary_f32(&program, &input);
    assert_eq!(magnitudes, [1.0f32, 2.0, 3.0]);
}

/// Negation flips the sign bit, including for zero: `neg(0.0)` must be `-0.0`.
#[test]
fn f32_neg_reference() {
    let program = f32_neg::F32Neg::program();
    let out = run_unary_f32(&program, &[1.0, -2.0, 0.0]);
    // IEEE-754 equality treats `-0.0 == 0.0`, so `assert_eq!` on the f32 values
    // alone would also accept `+0.0` for `neg(0.0)`. Compare raw bit patterns
    // so the sign of zero is actually checked (byte-identical output).
    let got: Vec<u32> = out.iter().map(|v| v.to_bits()).collect();
    let want: Vec<u32> = [-1.0f32, 2.0, -0.0].iter().map(|v| v.to_bits()).collect();
    assert_eq!(got, want, "neg output bits differ: got {:?}", out);
}

/// `cos(0)` is 1 within a small absolute tolerance.
#[test]
fn f32_cos_reference() {
    let program = f32_cos::F32Cos::program();
    let cos_zero = run_unary_f32(&program, &[0.0])[0];
    let error = (cos_zero - 1.0).abs();
    assert!(error < 1e-6, "cos(0) should be 1");
}

/// `sin(0)` is 0 within a small absolute tolerance.
#[test]
fn f32_sin_reference() {
    let sin_zero = run_unary_f32(&f32_sin::F32Sin::program(), &[0.0])[0];
    assert!(sin_zero.abs() < 1e-6, "sin(0) should be 0");
}

/// Square roots of perfect squares are exact.
#[test]
fn f32_sqrt_reference() {
    let roots = run_unary_f32(&f32_sqrt::F32Sqrt::program(), &[4.0, 9.0]);
    assert_eq!(roots, [2.0f32, 3.0]);
}

/// Ceiling rounds toward positive infinity for both signs.
#[test]
fn f32_ceil_reference() {
    let program = f32_ceil::F32Ceil::program();
    let rounded_up = run_unary_f32(&program, &[1.1, -1.1]);
    assert_eq!(rounded_up, [2.0f32, -1.0]);
}

/// Floor rounds toward negative infinity for both signs.
#[test]
fn f32_floor_reference() {
    let program = f32_floor::F32Floor::program();
    let rounded_down = run_unary_f32(&program, &[1.9, -1.1]);
    assert_eq!(rounded_down, [1.0f32, -2.0]);
}

/// Round-to-nearest, including the negative half-way case `-1.5 -> -2.0`.
#[test]
fn f32_round_reference() {
    let program = f32_round::F32Round::program();
    let samples = [1.4f32, 1.6, -1.5];
    let expected: [f32; 3] = [1.0, 2.0, -2.0];
    assert_eq!(run_unary_f32(&program, &samples), expected);
}

/// Truncation drops the fractional part, rounding toward zero.
#[test]
fn f32_trunc_reference() {
    let truncated = run_unary_f32(&f32_trunc::F32Trunc::program(), &[1.9, -1.9]);
    assert_eq!(truncated, [1.0f32, -1.0]);
}

/// Element-wise minimum picks the smaller lane from either side.
#[test]
fn f32_min_reference() {
    let program = f32_min::F32Min::program();
    let (left, right) = ([1.0f32, 3.0], [2.0f32, 2.0]);
    let smaller = run_binary_f32(&program, &left, &right);
    assert_eq!(smaller, [1.0f32, 2.0]);
}

/// Element-wise maximum picks the larger lane from either side.
#[test]
fn f32_max_reference() {
    let program = f32_max::F32Max::program();
    let (left, right) = ([1.0f32, 3.0], [2.0f32, 2.0]);
    let larger = run_binary_f32(&program, &left, &right);
    assert_eq!(larger, [2.0f32, 3.0]);
}

/// Clamp pins values below/above the bounds and passes in-range values through.
#[test]
fn f32_clamp_reference() {
    let program = f32_clamp::F32Clamp::program();
    // The expected output implies bounds low = 0.0, high = 2.0. The second
    // buffer below is all zeros, so it cannot encode high = 2.0 as f32 data.
    // NOTE(review): presumably F32Clamp fixes or packs its bounds somewhere in
    // the program definition, not in this buffer. Confirm against
    // `F32Clamp::program()`, then either fill this buffer with the real bounds
    // or document it as unused.
    assert_eq!(
        run_binary_f32(&program, &[-1.0, 1.0, 3.0], &[0.0, 0.0, 0.0]),
        &[0.0, 1.0, 2.0]
    );
}

/// Sign maps negatives to -1, zero to 0, and positives to 1.
#[test]
fn f32_sign_reference() {
    let program = f32_sign::F32Sign::program();
    let signs = run_unary_f32(&program, &[-2.0, 0.0, 3.0]);
    let expected: [f32; 3] = [-1.0, 0.0, 1.0];
    assert_eq!(signs, expected);
}

/// Equality yields 1 for equal lanes and 0 otherwise.
#[test]
fn f32_eq_reference() {
    let program = f32_eq::F32Eq::program();
    let lhs = [1.0f32, 1.0, 2.0];
    let rhs = [1.0f32, 2.0, 2.0];
    let expected: [u32; 3] = [1, 0, 1];
    assert_eq!(run_binary_f32_to_u32(&program, &lhs, &rhs), expected);
}

/// Strict less-than: the equal lane (1.0 vs 1.0) yields 0.
#[test]
fn f32_lt_reference() {
    let program = f32_lt::F32Lt::program();
    let lhs = [1.0f32, 2.0, 1.0];
    let rhs = [2.0f32, 1.0, 1.0];
    let flags = run_binary_f32_to_u32(&program, &lhs, &rhs);
    assert_eq!(flags, [1u32, 0, 0]);
}

/// Less-than-or-equal: the equal lane (1.0 vs 1.0) yields 1.
#[test]
fn f32_le_reference() {
    let program = f32_le::F32Le::program();
    let lhs = [1.0f32, 2.0, 1.0];
    let rhs = [2.0f32, 1.0, 1.0];
    let flags = run_binary_f32_to_u32(&program, &lhs, &rhs);
    assert_eq!(flags, [1u32, 0, 1]);
}

/// Three-way compare: -1 for less, 1 for greater, 0 for equal.
#[test]
fn f32_cmp_reference() {
    let program = f32_cmp::F32Cmp::program();
    let ordering = run_binary_f32_to_i32(&program, &[1.0, 2.0, 1.0], &[2.0, 1.0, 1.0]);
    let expected: [i32; 3] = [-1, 1, 0];
    assert_eq!(ordering, expected);
}

/// NaN detection flags NaN lanes with 1 and ordinary numbers with 0.
#[test]
fn f32_is_nan_reference() {
    let flags = run_unary_f32_to_u32(&f32_is_nan::F32IsNan::program(), &[f32::NAN, 1.0]);
    assert_eq!(flags, [1u32, 0]);
}

/// Infinity detection flags +inf with 1 and a finite value with 0.
#[test]
fn f32_is_inf_reference() {
    let program = f32_is_inf::F32IsInf::program();
    let samples = [f32::INFINITY, 1.0];
    let flags = run_unary_f32_to_u32(&program, &samples);
    assert_eq!(flags, [1u32, 0]);
}

/// Finiteness check: +inf yields 0, an ordinary value yields 1.
#[test]
fn f32_is_finite_reference() {
    let program = f32_is_finite::F32IsFinite::program();
    let samples = [f32::INFINITY, 1.0];
    let flags = run_unary_f32_to_u32(&program, &samples);
    assert_eq!(flags, [0u32, 1]);
}