use vyre::ops::primitive::float::*;
use vyre_conform::{reference::interp, spec::value::Value};
/// Serializes `f32` values into one contiguous little-endian byte buffer
/// (4 bytes per value, in input order).
fn f32_bytes(values: &[f32]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(values.len() * 4);
    for value in values {
        encoded.extend_from_slice(&value.to_le_bytes());
    }
    encoded
}
/// Serializes `u32` values into one contiguous little-endian byte buffer
/// (4 bytes per value, in input order).
fn u32_bytes(values: &[u32]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(values.len() * 4);
    for value in values {
        encoded.extend_from_slice(&value.to_le_bytes());
    }
    encoded
}
/// Serializes `i32` values into one contiguous little-endian byte buffer
/// (4 bytes per value, two's complement, in input order).
fn i32_bytes(values: &[i32]) -> Vec<u8> {
    let mut encoded = Vec::with_capacity(values.len() * 4);
    for value in values {
        encoded.extend_from_slice(&value.to_le_bytes());
    }
    encoded
}
/// Decodes a little-endian byte buffer into `f32` values, 4 bytes per value.
///
/// # Panics
///
/// Panics if `bytes.len()` is not a multiple of 4. A misaligned buffer from
/// the interpreter is a bug worth surfacing, not something to truncate silently.
fn read_f32_bytes(bytes: &[u8]) -> Vec<f32> {
    let chunks = bytes.chunks_exact(4);
    assert!(
        chunks.remainder().is_empty(),
        "f32 buffer length {} is not a multiple of 4",
        bytes.len()
    );
    chunks
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
/// Decodes a little-endian byte buffer into `u32` values, 4 bytes per value.
///
/// # Panics
///
/// Panics if `bytes.len()` is not a multiple of 4. A misaligned buffer from
/// the interpreter is a bug worth surfacing, not something to truncate silently.
fn read_u32_bytes(bytes: &[u8]) -> Vec<u32> {
    let chunks = bytes.chunks_exact(4);
    assert!(
        chunks.remainder().is_empty(),
        "u32 buffer length {} is not a multiple of 4",
        bytes.len()
    );
    chunks
        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
/// Decodes a little-endian byte buffer into `i32` values, 4 bytes per value
/// (two's complement).
///
/// # Panics
///
/// Panics if `bytes.len()` is not a multiple of 4. A misaligned buffer from
/// the interpreter is a bug worth surfacing, not something to truncate silently.
fn read_i32_bytes(bytes: &[u8]) -> Vec<i32> {
    let chunks = bytes.chunks_exact(4);
    assert!(
        chunks.remainder().is_empty(),
        "i32 buffer length {} is not a multiple of 4",
        bytes.len()
    );
    chunks
        .map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
/// Runs a unary `f32 -> f32` program through the reference interpreter.
///
/// The program is given two buffers: `input`, then a zero-filled output buffer
/// of the same length. The first value the interpreter returns is decoded as
/// `f32` lanes. NOTE(review): presumably that is the written output buffer;
/// confirm against `interp::run`.
///
/// Panics if the interpreter returns an error or no values.
fn run_unary_f32(program: &vyre::ir::Program, input: &[f32]) -> Vec<f32> {
    let zeroed = vec![0.0f32; input.len()];
    let buffers = [Value::Bytes(f32_bytes(input)), Value::Bytes(f32_bytes(&zeroed))];
    let outputs = interp::run(program, &buffers).unwrap();
    let first = &outputs[0];
    read_f32_bytes(&first.wide_bytes())
}
/// Runs a binary `(f32, f32) -> f32` program through the reference interpreter.
///
/// The program is given three buffers: `a`, `b`, then a zero-filled output
/// buffer with one lane per input pair. The first returned value is decoded
/// as `f32` lanes. NOTE(review): presumably that is the output buffer; confirm
/// against `interp::run`.
///
/// Panics if `a` and `b` differ in length, or if the interpreter fails.
fn run_binary_f32(program: &vyre::ir::Program, a: &[f32], b: &[f32]) -> Vec<f32> {
    assert_eq!(a.len(), b.len());
    let lanes = a.len();
    let buffers = [
        Value::Bytes(f32_bytes(a)),
        Value::Bytes(f32_bytes(b)),
        Value::Bytes(f32_bytes(&vec![0.0f32; lanes])),
    ];
    let outputs = interp::run(program, &buffers).unwrap();
    read_f32_bytes(&outputs[0].wide_bytes())
}
/// Runs a binary `(f32, f32) -> u32` program (e.g. a comparison) through the
/// reference interpreter.
///
/// The program is given `a`, `b`, and a zero-filled `u32` output buffer with
/// one lane per input pair. The first returned value is decoded as `u32` lanes.
/// NOTE(review): presumably that is the output buffer; confirm.
///
/// Panics if `a` and `b` differ in length, or if the interpreter fails.
fn run_binary_f32_to_u32(program: &vyre::ir::Program, a: &[f32], b: &[f32]) -> Vec<u32> {
    assert_eq!(a.len(), b.len());
    let lanes = a.len();
    let buffers = [
        Value::Bytes(f32_bytes(a)),
        Value::Bytes(f32_bytes(b)),
        Value::Bytes(u32_bytes(&vec![0u32; lanes])),
    ];
    let outputs = interp::run(program, &buffers).unwrap();
    read_u32_bytes(&outputs[0].wide_bytes())
}
/// Runs a binary `(f32, f32) -> i32` program (e.g. a three-way compare)
/// through the reference interpreter.
///
/// The program is given `a`, `b`, and a zero-filled `i32` output buffer with
/// one lane per input pair. The first returned value is decoded as `i32` lanes.
/// NOTE(review): presumably that is the output buffer; confirm.
///
/// Panics if `a` and `b` differ in length, or if the interpreter fails.
fn run_binary_f32_to_i32(program: &vyre::ir::Program, a: &[f32], b: &[f32]) -> Vec<i32> {
    assert_eq!(a.len(), b.len());
    let lanes = a.len();
    let buffers = [
        Value::Bytes(f32_bytes(a)),
        Value::Bytes(f32_bytes(b)),
        Value::Bytes(i32_bytes(&vec![0i32; lanes])),
    ];
    let outputs = interp::run(program, &buffers).unwrap();
    read_i32_bytes(&outputs[0].wide_bytes())
}
/// Runs a unary `f32 -> u32` program (e.g. a classification predicate)
/// through the reference interpreter.
///
/// The program is given `input` and a zero-filled `u32` output buffer of the
/// same length. The first returned value is decoded as `u32` lanes.
/// NOTE(review): presumably that is the output buffer; confirm.
///
/// Panics if the interpreter returns an error or no values.
fn run_unary_f32_to_u32(program: &vyre::ir::Program, input: &[f32]) -> Vec<u32> {
    let zeroed = vec![0u32; input.len()];
    let buffers = [Value::Bytes(f32_bytes(input)), Value::Bytes(u32_bytes(&zeroed))];
    let outputs = interp::run(program, &buffers).unwrap();
    let first = &outputs[0];
    read_u32_bytes(&first.wide_bytes())
}
// Arithmetic ops combine `a[i]` with `b[i]` lane by lane. Every chosen input
// yields an exactly representable result, so exact equality is safe here.
#[test]
fn f32_add_reference() {
    let lhs = [1.0f32, 2.0];
    let rhs = [3.0f32, 4.0];
    let got = run_binary_f32(&f32_add::F32Add::program(), &lhs, &rhs);
    assert_eq!(got, &[4.0, 6.0]);
}
#[test]
fn f32_sub_reference() {
    let lhs = [5.0f32, 3.0];
    let rhs = [2.0f32, 1.0];
    let got = run_binary_f32(&f32_sub::F32Sub::program(), &lhs, &rhs);
    assert_eq!(got, &[3.0, 2.0]);
}
#[test]
fn f32_mul_reference() {
    let lhs = [2.0f32, 3.0];
    let rhs = [3.0f32, 4.0];
    let got = run_binary_f32(&f32_mul::F32Mul::program(), &lhs, &rhs);
    assert_eq!(got, &[6.0, 12.0]);
}
#[test]
fn f32_div_reference() {
    let lhs = [6.0f32, 8.0];
    let rhs = [2.0f32, 4.0];
    let got = run_binary_f32(&f32_div::F32Div::program(), &lhs, &rhs);
    assert_eq!(got, &[3.0, 2.0]);
}
// `abs` clears the sign of negative lanes and leaves positive lanes unchanged.
#[test]
fn f32_abs_reference() {
    let inputs = [-1.0f32, 2.0, -3.0];
    let got = run_unary_f32(&f32_abs::F32Abs::program(), &inputs);
    assert_eq!(got, &[1.0, 2.0, 3.0]);
}
// `neg` must flip the sign of every lane, including zero: -(0.0) is -0.0.
// `assert_eq!` on f32 treats -0.0 == 0.0, so a value comparison alone cannot
// detect a wrong sign on zero. The results are compared by bit pattern instead.
#[test]
fn f32_neg_reference() {
    let program = f32_neg::F32Neg::program();
    let got = run_unary_f32(&program, &[1.0, -2.0, 0.0]);
    let expected = [-1.0f32, 2.0, -0.0];
    let got_bits: Vec<u32> = got.iter().map(|v| v.to_bits()).collect();
    let expected_bits: Vec<u32> = expected.iter().map(|v| v.to_bits()).collect();
    assert_eq!(
        got_bits, expected_bits,
        "neg results differ bitwise: got {:?}, expected {:?}",
        got, expected
    );
}
// Trigonometric ops may not be bit-exact, so results are checked within 1e-6.
#[test]
fn f32_cos_reference() {
    let program = f32_cos::F32Cos::program();
    let cos_zero = run_unary_f32(&program, &[0.0])[0];
    assert!((cos_zero - 1.0).abs() < 1e-6, "cos(0) should be 1");
}
#[test]
fn f32_sin_reference() {
    let program = f32_sin::F32Sin::program();
    let sin_zero = run_unary_f32(&program, &[0.0])[0];
    assert!(sin_zero.abs() < 1e-6, "sin(0) should be 0");
}
// Square root and the rounding family. Every expected value is an exact
// integer, so exact equality is used throughout.
#[test]
fn f32_sqrt_reference() {
    let got = run_unary_f32(&f32_sqrt::F32Sqrt::program(), &[4.0, 9.0]);
    assert_eq!(got, &[2.0, 3.0]);
}
#[test]
fn f32_ceil_reference() {
    let got = run_unary_f32(&f32_ceil::F32Ceil::program(), &[1.1, -1.1]);
    assert_eq!(got, &[2.0, -1.0]);
}
#[test]
fn f32_floor_reference() {
    let got = run_unary_f32(&f32_floor::F32Floor::program(), &[1.9, -1.1]);
    assert_eq!(got, &[1.0, -2.0]);
}
#[test]
fn f32_round_reference() {
    // -1.5 maps to -2.0 under both round-half-away-from-zero and
    // round-half-to-even, so this case does not pin the tie-breaking rule.
    let got = run_unary_f32(&f32_round::F32Round::program(), &[1.4, 1.6, -1.5]);
    assert_eq!(got, &[1.0, 2.0, -2.0]);
}
#[test]
fn f32_trunc_reference() {
    let got = run_unary_f32(&f32_trunc::F32Trunc::program(), &[1.9, -1.9]);
    assert_eq!(got, &[1.0, -1.0]);
}
// Selection and sign ops.
#[test]
fn f32_min_reference() {
    let lhs = [1.0f32, 3.0];
    let rhs = [2.0f32, 2.0];
    let got = run_binary_f32(&f32_min::F32Min::program(), &lhs, &rhs);
    assert_eq!(got, &[1.0, 2.0]);
}
#[test]
fn f32_max_reference() {
    let lhs = [1.0f32, 3.0];
    let rhs = [2.0f32, 2.0];
    let got = run_binary_f32(&f32_max::F32Max::program(), &lhs, &rhs);
    assert_eq!(got, &[2.0, 3.0]);
}
#[test]
fn f32_clamp_reference() {
    // NOTE(review): clamp is driven through the two-input helper with an
    // all-zero second operand, yet 3.0 maps to 2.0. The upper bound therefore
    // does not come from these inputs; confirm F32Clamp's buffer layout.
    let values = [-1.0f32, 1.0, 3.0];
    let second = [0.0f32, 0.0, 0.0];
    let got = run_binary_f32(&f32_clamp::F32Clamp::program(), &values, &second);
    assert_eq!(got, &[0.0, 1.0, 2.0]);
}
#[test]
fn f32_sign_reference() {
    let got = run_unary_f32(&f32_sign::F32Sign::program(), &[-2.0, 0.0, 3.0]);
    assert_eq!(got, &[-1.0, 0.0, 1.0]);
}
// Comparison ops. eq/lt/le produce a u32 per lane (1 for true, 0 for false);
// cmp produces an i32 per lane (-1 for less, 1 for greater, 0 for equal).
#[test]
fn f32_eq_reference() {
    let lhs = [1.0f32, 1.0, 2.0];
    let rhs = [1.0f32, 2.0, 2.0];
    let got = run_binary_f32_to_u32(&f32_eq::F32Eq::program(), &lhs, &rhs);
    assert_eq!(got, &[1, 0, 1]);
}
#[test]
fn f32_lt_reference() {
    let lhs = [1.0f32, 2.0, 1.0];
    let rhs = [2.0f32, 1.0, 1.0];
    let got = run_binary_f32_to_u32(&f32_lt::F32Lt::program(), &lhs, &rhs);
    assert_eq!(got, &[1, 0, 0]);
}
#[test]
fn f32_le_reference() {
    let lhs = [1.0f32, 2.0, 1.0];
    let rhs = [2.0f32, 1.0, 1.0];
    let got = run_binary_f32_to_u32(&f32_le::F32Le::program(), &lhs, &rhs);
    assert_eq!(got, &[1, 0, 1]);
}
#[test]
fn f32_cmp_reference() {
    let lhs = [1.0f32, 2.0, 1.0];
    let rhs = [2.0f32, 1.0, 1.0];
    let got = run_binary_f32_to_i32(&f32_cmp::F32Cmp::program(), &lhs, &rhs);
    assert_eq!(got, &[-1, 1, 0]);
}
// Classification predicates, one u32 per lane (1 = true, 0 = false). Each
// predicate is exercised on NaN, both infinities, and ordinary finite values,
// so the distinction between NaN and infinity is pinned in both directions.
#[test]
fn f32_is_nan_reference() {
    let program = f32_is_nan::F32IsNan::program();
    assert_eq!(
        run_unary_f32_to_u32(
            &program,
            &[f32::NAN, 1.0, f32::INFINITY, f32::NEG_INFINITY, 0.0]
        ),
        &[1, 0, 0, 0, 0]
    );
}
#[test]
fn f32_is_inf_reference() {
    let program = f32_is_inf::F32IsInf::program();
    // Both signs of infinity count. NaN shares the all-ones exponent but must not.
    assert_eq!(
        run_unary_f32_to_u32(
            &program,
            &[f32::INFINITY, 1.0, f32::NEG_INFINITY, f32::NAN, f32::MAX]
        ),
        &[1, 0, 1, 0, 0]
    );
}
#[test]
fn f32_is_finite_reference() {
    let program = f32_is_finite::F32IsFinite::program();
    // Neither infinity nor NaN is finite. Zero and f32::MAX are.
    assert_eq!(
        run_unary_f32_to_u32(
            &program,
            &[f32::INFINITY, 1.0, f32::NEG_INFINITY, f32::NAN, 0.0, f32::MAX]
        ),
        &[0, 1, 0, 0, 1, 1]
    );
}