use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program, UnOp};
use crate::region::wrap_anonymous;
use vyre_primitives::nn::f32_stability::flush_tiny;
pub(crate) fn silu_expr(x: Expr) -> Expr {
let sigmoid_x = Expr::div(
Expr::f32(1.0),
Expr::add(
Expr::f32(1.0),
Expr::UnOp {
op: UnOp::Exp,
operand: Box::new(Expr::UnOp {
op: UnOp::Negate,
operand: Box::new(x.clone()),
}),
},
),
);
flush_tiny(Expr::mul(x, sigmoid_x))
}
#[must_use]
pub fn silu(input: &str, output: &str, n: u32) -> Program {
let i = Expr::var("i");
let x = Expr::load(input, i.clone());
let body = vec![
Node::let_bind("i", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(i.clone(), Expr::buf_len(input)),
vec![Node::Store {
buffer: output.into(),
index: i,
value: silu_expr(x),
}],
),
];
Program::wrapped(
vec![
BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::F32).with_count(n),
BufferDecl::output(output, 1, DataType::F32).with_count(n),
],
[64, 1, 1],
vec![wrap_anonymous("vyre-libs::nn::silu", body)],
)
}
inventory::submit! {
crate::harness::OpEntry {
id: "vyre-libs::nn::silu",
build: || silu("input", "output", 4),
test_inputs: Some(|| {
let to_bytes = vyre_primitives::wire::pack_f32_slice;
vec![vec![
to_bytes(&[0.0_f32, 1.0, -1.0, 2.0]), ]]
}),
expected_output: Some(|| {
let input = [0.0_f32, 1.0, -1.0, 2.0];
let out: Vec<f32> = input
.iter()
.map(|x| x / (1.0 + (-x).exp()))
.collect();
let bytes = vyre_primitives::wire::pack_f32_slice(&out);
vec![vec![bytes]]
}),
category: Some("nn"),
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_support::byte_pack::decode_f32;
use crate::test_support::byte_pack::f32_bytes;
use vyre_reference::value::Value;
fn silu_ref(x: f32) -> f32 {
x / (1.0 + (-x).exp())
}
#[test]
fn silu_nan_input_propagates_nan() {
let input = [f32::NAN];
let program = silu("input", "output", 1);
let outputs = vyre_reference::reference_eval(
&program,
&[Value::from(f32_bytes(&input)), Value::from(vec![0u8; 4])],
)
.expect("Fix: silu must not panic on NaN input");
let out = decode_f32(&outputs[0].to_bytes());
assert!(out[0].is_nan(), "silu(NaN) must be NaN");
}
#[test]
fn silu_inf_inputs() {
let program = silu("input", "output", 2);
let outputs = vyre_reference::reference_eval(
&program,
&[
Value::from(f32_bytes(&[f32::INFINITY, 0.0])),
Value::from(vec![0u8; 8]),
],
)
.expect("Fix: silu must not panic on +Inf input");
let out = decode_f32(&outputs[0].to_bytes());
assert_eq!(out[0], f32::INFINITY, "silu(+Inf) must be +Inf");
let outputs = vyre_reference::reference_eval(
&program,
&[
Value::from(f32_bytes(&[f32::NEG_INFINITY, 0.0])),
Value::from(vec![0u8; 8]),
],
)
.expect("Fix: silu must not panic on -Inf input");
let out = decode_f32(&outputs[0].to_bytes());
assert!(
out[0].is_nan(),
"silu(-Inf) must be NaN (negative infinity times zero)"
);
}
#[test]
fn silu_negative_zero_vs_positive_zero() {
let program = silu("input", "output", 2);
let outputs = vyre_reference::reference_eval(
&program,
&[
Value::from(f32_bytes(&[0.0f32, -0.0f32])),
Value::from(vec![0u8; 8]),
],
)
.expect("Fix: silu must distinguish -0.0 from 0.0");
let out = decode_f32(&outputs[0].to_bytes());
assert_eq!(out[0].to_bits(), 0.0f32.to_bits());
assert!(out[1] == 0.0, "silu(-0.0) must be zero, got {}", out[1]);
}
#[test]
fn silu_subnormal_input_is_flushed_to_zero() {
let sub = f32::from_bits(1); let program = silu("input", "output", 1);
let outputs = vyre_reference::reference_eval(
&program,
&[Value::from(f32_bytes(&[sub])), Value::from(vec![0u8; 4])],
)
.expect("Fix: silu must not panic on subnormal input");
let out = decode_f32(&outputs[0].to_bytes());
assert_eq!(
out[0].to_bits(),
0.0f32.to_bits(),
"silu must flush tiny subnormal to +0.0"
);
}
#[test]
fn silu_all_zeros() {
let input = [0.0f32; 4];
let program = silu("input", "output", 4);
let outputs = vyre_reference::reference_eval(
&program,
&[Value::from(f32_bytes(&input)), Value::from(vec![0u8; 16])],
)
.expect("Fix: silu all-zeros must execute");
let out = decode_f32(&outputs[0].to_bytes());
assert_eq!(out, vec![0.0; 4]);
}
#[test]
fn silu_all_ones() {
let input = [1.0f32; 4];
let program = silu("input", "output", 4);
let outputs = vyre_reference::reference_eval(
&program,
&[Value::from(f32_bytes(&input)), Value::from(vec![0u8; 16])],
)
.expect("Fix: silu all-ones must execute");
let out = decode_f32(&outputs[0].to_bytes());
let expected = silu_ref(1.0);
for (i, &v) in out.iter().enumerate() {
assert!(
(v - expected).abs() <= 1.0e-6,
"silu all-ones mismatch at {i}: {v}"
);
}
}
#[test]
fn silu_all_max_f32() {
let input = [f32::MAX; 4];
let program = silu("input", "output", 4);
let outputs = vyre_reference::reference_eval(
&program,
&[Value::from(f32_bytes(&input)), Value::from(vec![0u8; 16])],
)
.expect("Fix: silu all-max-f32 must not panic");
let out = decode_f32(&outputs[0].to_bytes());
for (i, &v) in out.iter().enumerate() {
assert_eq!(
v,
f32::MAX,
"silu(f32::MAX) must be f32::MAX at {i}: got {v}"
);
}
}
#[test]
fn silu_single_element() {
let input = [2.5f32];
let program = silu("input", "output", 1);
let outputs = vyre_reference::reference_eval(
&program,
&[Value::from(f32_bytes(&input)), Value::from(vec![0u8; 4])],
)
.expect("Fix: silu single element must execute");
let out = decode_f32(&outputs[0].to_bytes());
let expected = silu_ref(2.5);
assert!(
(out[0] - expected).abs() <= 1.0e-6,
"silu single element mismatch: {} != {}",
out[0],
expected
);
}
#[test]
fn silu_empty_tensor() {
let program = silu("input", "output", 0);
let outputs =
vyre_reference::reference_eval(&program, &[Value::from(vec![]), Value::from(vec![])])
.expect("Fix: silu n=0 must not panic");
assert!(outputs[0].to_bytes().is_empty());
}
use proptest::prelude::*;
proptest! {
#[test]
fn silu_output_invariant_for_finite_inputs(x in -1e10f32..1e10f32) {
let program = silu("input", "output", 1);
let outputs = vyre_reference::reference_eval(
&program,
&[Value::from(f32_bytes(&[x])), Value::from(vec![0u8; 4])],
)
.expect("Fix: silu must not panic on finite input");
let out = decode_f32(&outputs[0].to_bytes())[0];
if x.is_nan() {
prop_assert!(out.is_nan());
} else if x > 0.0 {
prop_assert!(out > 0.0 && out <= x, "silu(x) for x>0 must be in (0, x]");
} else if x < 0.0 {
prop_assert!(out >= x && out <= 0.0, "silu(x) for x<0 must be in [x, 0]");
} else {
prop_assert_eq!(out.to_bits(), 0.0f32.to_bits());
}
}
}
}