use aprender::nn::functional::swiglu_scalar;
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(256))]
#[test]
fn prop_gate_bounded(
    x in proptest::collection::vec(-10.0f32..10.0, 1..32usize),
    gate in proptest::collection::vec(-10.0f32..10.0, 1..32usize)
) {
    // Truncate both vectors to a common length before building tensors.
    let len = x.len().min(gate.len());
    let xs = aprender::autograd::Tensor::from_slice(&x[..len]);
    let gs = aprender::autograd::Tensor::from_slice(&gate[..len]);
    let out = aprender::nn::functional::swiglu(&xs, &gs);
    // SwiGLU(x, g) = x * SiLU(g). |SiLU(g)| never exceeds |g| + 0.28
    // (the SiLU minimum is ≈ -0.2785), so every output element must be
    // bounded by |x| * (|g| + 0.28), plus a small float slack.
    for (i, &val) in out.data().iter().take(len).enumerate() {
        let bound = x[i].abs() * (gate[i].abs() + 0.28);
        prop_assert!(
            val.abs() <= bound + 1e-5,
            "SwiGLU[{i}] = {val}, bound = {bound}"
        );
    }
}
#[test]
fn prop_zero_input(
    gate in proptest::collection::vec(-10.0f32..10.0, 1..32usize)
) {
    // A zero linear branch must annihilate the output regardless of the
    // gate value: SwiGLU(0, g) = 0 * SiLU(g) = 0 (to float precision).
    for g in gate.iter().copied() {
        let val = swiglu_scalar(0.0, g);
        prop_assert!(
            val.abs() < 1e-7,
            "SwiGLU(0, {g}) = {val}, expected 0.0"
        );
    }
}
#[test]
fn prop_finite_output(
    x in proptest::collection::vec(-10.0f32..10.0, 1..32usize),
    gate in proptest::collection::vec(-10.0f32..10.0, 1..32usize)
) {
    // Bounded inputs in [-10, 10) must never yield NaN or infinity.
    // `zip` truncates to the shorter vector, matching the original
    // explicit `min`-length index loop.
    for (&xi, &gi) in x.iter().zip(gate.iter()) {
        let val = swiglu_scalar(xi, gi);
        prop_assert!(
            val.is_finite(),
            "SwiGLU({}, {}) = {val} — not finite",
            xi, gi
        );
    }
}
#[test]
fn prop_simd_equivalence(
    x in proptest::collection::vec(-10.0f32..10.0, 1..32usize),
    gate in proptest::collection::vec(-10.0f32..10.0, 1..32usize)
) {
    // The vectorized tensor path must agree element-wise with the scalar
    // reference implementation, to within a loose f32 tolerance.
    let len = x.len().min(gate.len());
    let lhs = aprender::autograd::Tensor::from_slice(&x[..len]);
    let rhs = aprender::autograd::Tensor::from_slice(&gate[..len]);
    let tensor_out = aprender::nn::functional::swiglu(&lhs, &rhs);
    let data = tensor_out.data();
    for i in 0..len {
        let expected = swiglu_scalar(x[i], gate[i]);
        let actual = data[i];
        let diff = (actual - expected).abs();
        prop_assert!(
            diff < 1e-4,
            "Tensor vs scalar mismatch at [{i}]: tensor={actual}, scalar={expected}, diff={diff}"
        );
    }
}
}