use vyre::ir::{Expr, Program};
/// Operation identifier for the squared leaky ReLU. It is handed to
/// `f32_unary_activation_program` by [`leaky_relu_sq`] and used as the `id` of this
/// file's `inventory` registry entry.
const OP_ID: &str = "vyre-libs::nn::leaky_relu_sq";
/// Builds the per-element IR expression `max(0.5 * x, x)^2`.
///
/// `max(0.5 * x, x)` is a leaky ReLU with negative-side slope 0.5: non-negative inputs
/// pass through unchanged and negative inputs are halved. Squaring that value gives
/// `x^2` for `x >= 0` and `0.25 * x^2` for `x < 0`.
fn leaky_relu_sq_expr(x: Expr) -> Expr {
    // 0.5 * x exceeds x only when x is negative, so the max selects the leaky branch there.
    let activated = Expr::max(Expr::mul(Expr::f32(0.5), x.clone()), x);
    // Square by multiplying the activated expression with a copy of itself.
    Expr::mul(activated.clone(), activated)
}
/// Builds a [`Program`] that applies the squared leaky ReLU, `max(0.5 * x, x)^2`,
/// to `n` `f32` elements.
///
/// `input` and `output` are the buffer names forwarded to the shared unary
/// activation builder (`super::unary::f32_unary_activation_program`), together with
/// this op's id and its per-element expression. In this file's tests, the first buffer
/// bound at evaluation time carries the `n` input values. The decoded first output
/// carries the `n` results.
#[must_use]
pub fn leaky_relu_sq(input: &str, output: &str, n: u32) -> Program {
    use super::unary::f32_unary_activation_program;

    f32_unary_activation_program(OP_ID, input, output, n, leaky_relu_sq_expr)
}
/// Sample input for the registry entry below.
///
/// It is shared by `build`, `test_inputs` and `expected_output`. This keeps the
/// element count, the packed input and the expected results in agreement.
const REGISTRY_SAMPLE: [f32; 4] = [0.0, 2.0, -4.0, 1.0];

inventory::submit! {
    crate::harness::OpEntry {
        id: OP_ID,
        // The element count is derived from the sample so the program size always
        // matches the data bound to it (the length is 4, so the cast cannot truncate).
        build: || leaky_relu_sq("input", "output", REGISTRY_SAMPLE.len() as u32),
        test_inputs: Some(|| {
            let to_bytes = vyre_primitives::wire::pack_f32_slice;
            vec![vec![to_bytes(&REGISTRY_SAMPLE)]]
        }),
        expected_output: Some(|| {
            // Host-side scalar mirror of `leaky_relu_sq_expr`: max(0.5 * x, x), squared.
            let out: Vec<f32> = REGISTRY_SAMPLE
                .iter()
                .map(|&x| {
                    let leaky = (0.5 * x).max(x);
                    leaky * leaky
                })
                .collect();
            let bytes = vyre_primitives::wire::pack_f32_slice(&out);
            vec![vec![bytes]]
        }),
        category: Some("nn"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::byte_pack::decode_f32;
    use crate::test_support::byte_pack::f32_bytes;
    use proptest::prelude::*;
    use vyre_reference::value::Value;

    /// Host-side scalar mirror of `leaky_relu_sq_expr`: `max(0.5 * x, x)` squared.
    fn leaky_relu_sq_ref(x: f32) -> f32 {
        let leaky = (0.5 * x).max(x);
        leaky * leaky
    }

    /// Builds `leaky_relu_sq` for `input.len()` elements and runs it through the
    /// reference evaluator.
    ///
    /// The output buffer starts zeroed. The function decodes the first output buffer
    /// as `f32` values.
    ///
    /// # Panics
    ///
    /// Panics with `context` if the reference evaluator returns an error.
    fn eval_leaky_relu_sq(input: &[f32], context: &str) -> Vec<f32> {
        let n = u32::try_from(input.len()).expect("test input length must fit in u32");
        let program = leaky_relu_sq("input", "output", n);
        let outputs = vyre_reference::reference_eval(
            &program,
            &[
                Value::from(f32_bytes(input)),
                Value::from(vec![0u8; input.len() * core::mem::size_of::<f32>()]),
            ],
        )
        .expect(context);
        decode_f32(&outputs[0].to_bytes())
    }

    #[test]
    fn leaky_relu_sq_nan_input_propagates_nan() {
        let out = eval_leaky_relu_sq(
            &[f32::NAN],
            "Fix: leaky_relu_sq must not panic on NaN input",
        );
        assert!(out[0].is_nan(), "leaky_relu_sq(NaN) must be NaN");
    }

    #[test]
    fn leaky_relu_sq_inf_inputs() {
        let out = eval_leaky_relu_sq(
            &[f32::INFINITY, 0.0],
            "Fix: leaky_relu_sq must not panic on +Inf input",
        );
        assert_eq!(out[0], f32::INFINITY, "leaky_relu_sq(+Inf) must be +Inf");
        assert_eq!(out[1], 0.0, "leaky_relu_sq(0.0) must be 0.0 next to +Inf");

        let out = eval_leaky_relu_sq(
            &[f32::NEG_INFINITY, 0.0],
            "Fix: leaky_relu_sq must not panic on -Inf input",
        );
        assert_eq!(
            out[0],
            f32::INFINITY,
            "leaky_relu_sq(-Inf) must be +Inf (square of negative infinity)"
        );
        assert_eq!(out[1], 0.0, "leaky_relu_sq(0.0) must be 0.0 next to -Inf");
    }

    #[test]
    fn leaky_relu_sq_negative_zero_vs_positive_zero() {
        let out = eval_leaky_relu_sq(&[0.0, -0.0], "Fix: leaky_relu_sq must handle -0.0");
        assert_eq!(out[0].to_bits(), 0.0f32.to_bits());
        assert_eq!(
            out[1].to_bits(),
            0.0f32.to_bits(),
            "leaky_relu_sq(-0.0) must be +0.0"
        );
    }

    #[test]
    fn leaky_relu_sq_subnormal_input() {
        let sub = f32::from_bits(1);
        let out = eval_leaky_relu_sq(
            &[sub],
            "Fix: leaky_relu_sq must not panic on subnormal input",
        );
        let expected = leaky_relu_sq_ref(sub);
        assert!(
            (out[0] - expected).abs() <= 1.0e-6,
            "leaky_relu_sq(subnormal) mismatch"
        );
    }

    #[test]
    fn generated_leaky_relu_sq_matches_scalar_reference() {
        let input = (0..2048u32)
            .map(|i| ((i as f32) * 0.031).cos() * 8.0 - 4.0)
            .collect::<Vec<_>>();
        let out = eval_leaky_relu_sq(
            &input,
            "Fix: generated leaky_relu_sq corpus must execute",
        );
        for (index, (actual, expected)) in out
            .iter()
            .copied()
            .zip(input.iter().copied().map(leaky_relu_sq_ref))
            .enumerate()
        {
            assert!(
                (actual - expected).abs() <= 1.0e-5,
                "generated leaky_relu_sq mismatch at {index}: {actual} != {expected}"
            );
        }
    }

    #[test]
    fn leaky_relu_sq_all_zeros() {
        let out = eval_leaky_relu_sq(&[0.0f32; 4], "Fix: leaky_relu_sq all-zeros must execute");
        assert_eq!(out, vec![0.0; 4]);
    }

    #[test]
    fn leaky_relu_sq_all_ones() {
        let out = eval_leaky_relu_sq(&[1.0f32; 4], "Fix: leaky_relu_sq all-ones must execute");
        assert_eq!(out, vec![1.0; 4]);
    }

    #[test]
    fn leaky_relu_sq_all_max_f32() {
        let out = eval_leaky_relu_sq(
            &[f32::MAX; 4],
            "Fix: leaky_relu_sq all-max-f32 must not panic",
        );
        for (i, &v) in out.iter().enumerate() {
            assert_eq!(
                v,
                f32::INFINITY,
                "leaky_relu_sq(f32::MAX) must overflow to +Inf at {i}: got {v}"
            );
        }
    }

    #[test]
    fn leaky_relu_sq_single_element() {
        let out = eval_leaky_relu_sq(
            &[-3.0f32],
            "Fix: leaky_relu_sq single element must execute",
        );
        let expected = leaky_relu_sq_ref(-3.0);
        assert!(
            (out[0] - expected).abs() <= 1.0e-5,
            "leaky_relu_sq single element mismatch"
        );
    }

    #[test]
    fn leaky_relu_sq_empty_tensor() {
        // Kept explicit (not via the helper) so the raw, undecoded output bytes are checked.
        let program = leaky_relu_sq("input", "output", 0);
        let outputs =
            vyre_reference::reference_eval(&program, &[Value::from(vec![]), Value::from(vec![])])
                .expect("Fix: leaky_relu_sq n=0 must not panic");
        assert!(outputs[0].to_bytes().is_empty());
    }

    proptest! {
        // `ANY` covers NaN, ±Inf, ±0 and subnormals as well as normal values. The
        // previous `NORMAL` strategy could never reach the NaN branch.
        #[test]
        fn leaky_relu_sq_output_is_nonnegative(x in prop::num::f32::ANY) {
            let out = eval_leaky_relu_sq(
                &[x],
                "Fix: leaky_relu_sq must not panic on any f32 input",
            )[0];
            if x.is_nan() {
                prop_assert!(out.is_nan(), "leaky_relu_sq(NaN) must be NaN, got {out}");
            } else {
                // A square of a non-NaN value is never negative and never NaN.
                prop_assert!(out >= 0.0, "leaky_relu_sq({x}) must be >= 0, got {out}");
            }
        }
    }
}