use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::region::wrap_anonymous;
const OP_ID: &str = "vyre-libs::math::weighted_sum_fma_f32";
pub fn weighted_sum_fma_f32(
weights: &str,
values: &str,
output: &str,
n: u32,
) -> Result<Program, String> {
if n == 0 {
return Err("Fix: weighted_sum_fma_f32 n=0 is invalid: empty reduction".to_string());
}
let body = vec![
Node::let_bind("acc", Expr::f32(0.0)),
Node::loop_for(
"i",
Expr::u32(0),
Expr::u32(n),
vec![Node::assign(
"acc",
Expr::Fma {
a: Box::new(Expr::load(weights, Expr::var("i"))),
b: Box::new(Expr::load(values, Expr::var("i"))),
c: Box::new(Expr::var("acc")),
},
)],
),
Node::Store {
buffer: output.into(),
index: Expr::u32(0),
value: Expr::var("acc"),
},
];
Ok(Program::wrapped(
vec![
BufferDecl::storage(weights, 0, BufferAccess::ReadOnly, DataType::F32).with_count(n),
BufferDecl::storage(values, 1, BufferAccess::ReadOnly, DataType::F32).with_count(n),
BufferDecl::output(output, 2, DataType::F32).with_count(1),
],
[1, 1, 1],
vec![wrap_anonymous(OP_ID, body)],
))
}
inventory::submit! {
crate::harness::OpEntry {
id: OP_ID,
build: || {
weighted_sum_fma_f32("weights", "values", "output", 4).unwrap_or_else(|error| {
crate::builder::invalid_output_program(
OP_ID,
"output",
DataType::F32,
error,
)
})
},
test_inputs: Some(|| {
let weights = crate::test_support::byte_pack::f32_bytes(&[0.5, 0.25, 0.125, 0.125]);
let values = crate::test_support::byte_pack::f32_bytes(&[1.0, 2.0, 4.0, 8.0]);
vec![vec![weights, values]]
}),
expected_output: Some(|| {
vec![vec![crate::test_support::byte_pack::f32_bytes(&[2.5_f32])]]
}),
category: Some("math"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_support::byte_pack::decode_f32_one as decode_one;
use crate::test_support::byte_pack::f32_bytes;
use vyre_reference::value::Value;
fn run(weights: &[f32], values: &[f32]) -> f32 {
let n = weights.len() as u32;
let prog = weighted_sum_fma_f32("weights", "values", "output", n).expect("Fix: build");
let outputs = vyre_reference::reference_eval(
&prog,
&[
Value::from(f32_bytes(weights)),
Value::from(f32_bytes(values)),
Value::from(vec![0u8; 4]),
],
)
.expect("Fix: weighted_sum_fma_f32 must execute in the reference interpreter.");
decode_one(&outputs[0].to_bytes())
}
#[test]
fn weighted_sum_fma_canonical_fixture() {
let weights = [0.5_f32, 0.25, 0.125, 0.125];
let values = [1.0_f32, 2.0, 4.0, 8.0];
let actual = run(&weights, &values);
assert!((actual - 2.5).abs() <= 1.0e-6, "{actual} != 2.5");
}
#[test]
fn weighted_sum_fma_zero_weights_returns_zero() {
let weights = [0.0_f32; 4];
let values = [1.0_f32, 2.0, 3.0, 4.0];
let actual = run(&weights, &values);
assert_eq!(actual, 0.0);
}
#[test]
fn weighted_sum_fma_unit_weights_equals_plain_sum() {
let weights = [1.0_f32; 8];
let values: Vec<f32> = (0..8).map(|i| i as f32 - 3.5).collect();
let actual = run(&weights, &values);
let expected: f32 = values.iter().sum();
assert!(
(actual - expected).abs() <= 1.0e-5,
"{actual} != {expected}"
);
}
#[test]
fn weighted_sum_fma_matches_naive_on_random_fuzz() {
let mut state = 0xFEEDBEEF_u64;
let mut next = || {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
((state >> 33) as f32 / (u32::MAX as f32 / 2.0)) - 1.0
};
for _ in 0..50 {
let n = 8;
let weights: Vec<f32> = (0..n).map(|_| next()).collect();
let values: Vec<f32> = (0..n).map(|_| next()).collect();
let actual = run(&weights, &values);
let expected: f32 = weights.iter().zip(values.iter()).map(|(w, v)| w * v).sum();
assert!(
(actual - expected).abs() <= 1.0e-4,
"fma={actual} naive={expected} diff={}",
(actual - expected).abs()
);
}
}
#[test]
fn weighted_sum_fma_rejects_empty_n() {
let err =
weighted_sum_fma_f32("weights", "values", "output", 0).expect_err("empty n must error");
assert!(err.contains("n=0"));
}
#[test]
fn weighted_sum_fma_single_element() {
let weights = [3.0_f32];
let values = [4.0_f32];
let actual = run(&weights, &values);
assert!((actual - 12.0).abs() <= 1.0e-5, "3*4 = 12, got {actual}");
}
#[test]
fn weighted_sum_fma_nan_in_weights_propagates() {
let weights = [1.0_f32, f32::NAN, 1.0];
let values = [1.0_f32, 1.0, 1.0];
let actual = run(&weights, &values);
assert!(
actual.is_nan(),
"weighted sum with NaN weight must be NaN, got {actual}"
);
}
#[test]
fn weighted_sum_fma_nan_in_values_propagates() {
let weights = [1.0_f32, 1.0, 1.0];
let values = [1.0_f32, f32::NAN, 1.0];
let actual = run(&weights, &values);
assert!(
actual.is_nan(),
"weighted sum with NaN value must be NaN, got {actual}"
);
}
#[test]
fn weighted_sum_fma_inf_in_weights_propagates() {
let weights = [1.0_f32, f32::INFINITY, 1.0];
let values = [1.0_f32, 1.0, 1.0];
let actual = run(&weights, &values);
assert!(
actual.is_infinite(),
"weighted sum with Inf weight must be Inf, got {actual}"
);
}
#[test]
fn weighted_sum_fma_inf_in_values_propagates() {
let weights = [1.0_f32, 1.0, 1.0];
let values = [1.0_f32, f32::INFINITY, 1.0];
let actual = run(&weights, &values);
assert!(
actual.is_infinite(),
"weighted sum with Inf value must be Inf, got {actual}"
);
}
}