use scirs2_core::simd_ops::SimdUnifiedOps;
/// In-place element-wise addition: `out[i] += rhs[i]`, delegating to the
/// SIMD-accelerated kernel from `scirs2_core`.
/// NOTE(review): behavior on mismatched slice lengths is determined by
/// `simd_add_inplace` — confirm whether it panics or truncates.
pub fn add_assign_f32(out: &mut [f32], rhs: &[f32]) {
<f32 as SimdUnifiedOps>::simd_add_inplace(out, rhs);
}
/// In-place element-wise subtraction: `out[i] -= rhs[i]`, delegating to the
/// SIMD-accelerated kernel from `scirs2_core`.
/// NOTE(review): behavior on mismatched slice lengths is determined by
/// `simd_sub_inplace` — confirm whether it panics or truncates.
pub fn sub_assign_f32(out: &mut [f32], rhs: &[f32]) {
<f32 as SimdUnifiedOps>::simd_sub_inplace(out, rhs);
}
/// In-place element-wise multiplication: `out[i] *= rhs[i]`, delegating to
/// the SIMD-accelerated kernel from `scirs2_core`.
/// NOTE(review): behavior on mismatched slice lengths is determined by
/// `simd_mul_inplace` — confirm whether it panics or truncates.
pub fn mul_assign_f32(out: &mut [f32], rhs: &[f32]) {
<f32 as SimdUnifiedOps>::simd_mul_inplace(out, rhs);
}
/// In-place element-wise division: `out[i] /= rhs[i]`, delegating to the
/// SIMD-accelerated kernel from `scirs2_core`. Division by zero is expected
/// to follow IEEE 754 (±inf, or NaN for 0/0) per the test suite below.
/// NOTE(review): behavior on mismatched slice lengths is determined by
/// `simd_div_inplace` — confirm whether it panics or truncates.
pub fn div_assign_f32(out: &mut [f32], rhs: &[f32]) {
<f32 as SimdUnifiedOps>::simd_div_inplace(out, rhs);
}
/// Element-wise addition into a separate output buffer: `out[i] = a[i] + b[i]`,
/// delegating to the SIMD-accelerated kernel from `scirs2_core`.
/// NOTE(review): behavior when `a`, `b`, and `out` lengths differ is
/// determined by `simd_add_into` — confirm whether it panics or truncates.
pub fn add_into_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
<f32 as SimdUnifiedOps>::simd_add_into(a, b, out);
}
/// Element-wise subtraction into a separate output buffer: `out[i] = a[i] - b[i]`,
/// delegating to the SIMD-accelerated kernel from `scirs2_core`.
/// NOTE(review): behavior when `a`, `b`, and `out` lengths differ is
/// determined by `simd_sub_into` — confirm whether it panics or truncates.
pub fn sub_into_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
<f32 as SimdUnifiedOps>::simd_sub_into(a, b, out);
}
/// Element-wise multiplication into a separate output buffer: `out[i] = a[i] * b[i]`,
/// delegating to the SIMD-accelerated kernel from `scirs2_core`.
/// NOTE(review): behavior when `a`, `b`, and `out` lengths differ is
/// determined by `simd_mul_into` — confirm whether it panics or truncates.
pub fn mul_into_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
<f32 as SimdUnifiedOps>::simd_mul_into(a, b, out);
}
/// Element-wise division into a separate output buffer: `out[i] = a[i] / b[i]`,
/// delegating to the SIMD-accelerated kernel from `scirs2_core`. Division by
/// zero is expected to follow IEEE 754 (±inf, or NaN for 0/0) per the test
/// suite below.
/// NOTE(review): behavior when `a`, `b`, and `out` lengths differ is
/// determined by `simd_div_into` — confirm whether it panics or truncates.
pub fn div_into_f32(a: &[f32], b: &[f32], out: &mut [f32]) {
<f32 as SimdUnifiedOps>::simd_div_into(a, b, out);
}
/// Applies ReLU in place: strictly negative values become `0.0`.
///
/// NaN, `+inf`, `-0.0`, and all non-negative values pass through
/// unchanged. No explicit NaN guard is needed: per IEEE 754, any
/// ordered comparison with NaN (here `*x < 0.0`) is `false`, so NaN
/// falls through the branch and propagates untouched — same behavior
/// as the previous `!x.is_nan() && *x < 0.0` form, minus the
/// redundant check.
pub fn relu_assign_f32(out: &mut [f32]) {
    for x in out.iter_mut() {
        if *x < 0.0 {
            *x = 0.0;
        }
    }
}
/// Applies leaky ReLU in place: strictly negative values are scaled by
/// `negative_slope`; everything else is left unchanged.
///
/// NaN needs no explicit guard: any ordered comparison with NaN (here
/// `*x < 0.0`) is `false` per IEEE 754, so NaN skips the branch and
/// propagates untouched — identical to the previous
/// `!x.is_nan() && *x < 0.0` form, minus the redundant check.
pub fn leaky_relu_assign_f32(out: &mut [f32], negative_slope: f32) {
    for x in out.iter_mut() {
        if *x < 0.0 {
            *x *= negative_slope;
        }
    }
}
/// Clamps every element in place to `[min_val, max_val]`; NaN elements
/// are left unchanged.
///
/// The previous explicit `is_nan()` skip was redundant: both
/// `*x < min_val` and `*x > max_val` are `false` for NaN per IEEE 754,
/// so NaN already falls through both branches untouched.
///
/// Deliberately NOT using `f32::clamp`, which panics when
/// `min_val > max_val`; this function silently does nothing useful in
/// that case, preserving the original (non-panicking) contract.
pub fn clamp_assign_f32(out: &mut [f32], min_val: f32, max_val: f32) {
    for x in out.iter_mut() {
        if *x < min_val {
            *x = min_val;
        } else if *x > max_val {
            *x = max_val;
        }
    }
}
/// The four element-wise binary operations supported by the f32 SIMD
/// kernels above. Used as a runtime selector via `dispatch_into` /
/// `dispatch_inplace`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BinaryF32Op {
// Element-wise addition.
Add,
// Element-wise subtraction.
Sub,
// Element-wise multiplication.
Mul,
// Element-wise division (IEEE 754 semantics; no zero-divisor check).
Div,
}
impl BinaryF32Op {
pub fn dispatch_into(self, a: &[f32], b: &[f32], out: &mut [f32]) {
match self {
BinaryF32Op::Add => add_into_f32(a, b, out),
BinaryF32Op::Sub => sub_into_f32(a, b, out),
BinaryF32Op::Mul => mul_into_f32(a, b, out),
BinaryF32Op::Div => div_into_f32(a, b, out),
}
}
pub fn dispatch_inplace(self, out: &mut [f32], rhs: &[f32]) {
match self {
BinaryF32Op::Add => add_assign_f32(out, rhs),
BinaryF32Op::Sub => sub_assign_f32(out, rhs),
BinaryF32Op::Mul => mul_assign_f32(out, rhs),
BinaryF32Op::Div => div_assign_f32(out, rhs),
}
}
}
#[cfg(test)]
mod tests {
// Parity (SIMD vs scalar) and edge-case tests for the wrappers and
// scalar helpers above.
use super::*;
use approx::assert_relative_eq;
// Sizes chosen to cover: empty input, a single element, counts below /
// at / above typical SIMD lane widths, and large buffers.
const TEST_SIZES: &[usize] = &[0, 1, 7, 8, 16, 1023, 1024, 4096, 65536];
// Builds two deterministic input vectors of length `n`. Both series are
// strictly positive for n >= 0, so division tests never hit a zero
// denominator.
fn make_vecs(n: usize) -> (Vec<f32>, Vec<f32>) {
let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.5 + 1.0).collect();
let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.3 + 0.1).collect();
(a, b)
}
// SIMD in-place add must match the scalar `x + y` result elementwise.
#[test]
fn test_add_assign_parity_with_scalar() {
for &n in TEST_SIZES {
let (a, b) = make_vecs(n);
let mut out_simd = a.clone();
add_assign_f32(&mut out_simd, &b);
let out_scalar: Vec<f32> = a.iter().zip(b.iter()).map(|(x, y)| x + y).collect();
for (s, r) in out_simd.iter().zip(out_scalar.iter()) {
assert_relative_eq!(s, r, epsilon = 1e-6);
}
}
}
// SIMD in-place subtract must match the scalar `x - y` result.
#[test]
fn test_sub_assign_parity_with_scalar() {
for &n in TEST_SIZES {
let (a, b) = make_vecs(n);
let mut out_simd = a.clone();
sub_assign_f32(&mut out_simd, &b);
let out_scalar: Vec<f32> = a.iter().zip(b.iter()).map(|(x, y)| x - y).collect();
for (s, r) in out_simd.iter().zip(out_scalar.iter()) {
assert_relative_eq!(s, r, epsilon = 1e-6);
}
}
}
// SIMD in-place multiply must match the scalar `x * y` result.
// Looser epsilon (1e-5) since products grow larger than sums here.
#[test]
fn test_mul_assign_parity_with_scalar() {
for &n in TEST_SIZES {
let (a, b) = make_vecs(n);
let mut out_simd = a.clone();
mul_assign_f32(&mut out_simd, &b);
let out_scalar: Vec<f32> = a.iter().zip(b.iter()).map(|(x, y)| x * y).collect();
for (s, r) in out_simd.iter().zip(out_scalar.iter()) {
assert_relative_eq!(s, r, epsilon = 1e-5);
}
}
}
// SIMD in-place divide must match the scalar `x / y` result.
#[test]
fn test_div_assign_parity_with_scalar() {
for &n in TEST_SIZES {
let (a, b) = make_vecs(n);
let mut out_simd = a.clone();
div_assign_f32(&mut out_simd, &b);
let out_scalar: Vec<f32> = a.iter().zip(b.iter()).map(|(x, y)| x / y).collect();
for (s, r) in out_simd.iter().zip(out_scalar.iter()) {
assert_relative_eq!(s, r, epsilon = 1e-5);
}
}
}
// Out-of-place add: out[i] == a[i] + b[i].
#[test]
fn test_add_into_f32_parity() {
for &n in TEST_SIZES {
let (a, b) = make_vecs(n);
let mut out = vec![0.0f32; n];
add_into_f32(&a, &b, &mut out);
for ((aa, bb), rr) in a.iter().zip(b.iter()).zip(out.iter()) {
assert_relative_eq!(*rr, *aa + *bb, epsilon = 1e-6);
}
}
}
// Out-of-place subtract: out[i] == a[i] - b[i].
#[test]
fn test_sub_into_f32_parity() {
for &n in TEST_SIZES {
let (a, b) = make_vecs(n);
let mut out = vec![0.0f32; n];
sub_into_f32(&a, &b, &mut out);
for ((aa, bb), rr) in a.iter().zip(b.iter()).zip(out.iter()) {
assert_relative_eq!(*rr, *aa - *bb, epsilon = 1e-6);
}
}
}
// Out-of-place multiply: out[i] == a[i] * b[i].
#[test]
fn test_mul_into_f32_parity() {
for &n in TEST_SIZES {
let (a, b) = make_vecs(n);
let mut out = vec![0.0f32; n];
mul_into_f32(&a, &b, &mut out);
for ((aa, bb), rr) in a.iter().zip(b.iter()).zip(out.iter()) {
assert_relative_eq!(*rr, *aa * *bb, epsilon = 1e-5);
}
}
}
// Out-of-place divide: out[i] == a[i] / b[i].
#[test]
fn test_div_into_f32_parity() {
for &n in TEST_SIZES {
let (a, b) = make_vecs(n);
let mut out = vec![0.0f32; n];
div_into_f32(&a, &b, &mut out);
for ((aa, bb), rr) in a.iter().zip(b.iter()).zip(out.iter()) {
assert_relative_eq!(*rr, *aa / *bb, epsilon = 1e-5);
}
}
}
// Division by zero must follow IEEE 754: ±x/0 -> ±inf, 0/0 -> NaN.
// Guards against a SIMD kernel masking or trapping on zero divisors.
#[test]
fn test_div_by_zero_produces_inf() {
let a = vec![1.0f32, -1.0, 0.0];
let b = vec![0.0f32, 0.0, 0.0];
let mut out = vec![0.0f32; 3];
div_into_f32(&a, &b, &mut out);
assert!(out[0].is_infinite() && out[0] > 0.0);
assert!(out[1].is_infinite() && out[1] < 0.0);
assert!(out[2].is_nan(), "0.0 / 0.0 should be NaN, got {}", out[2]);
}
// ReLU edge cases: negatives zeroed, NaN preserved, +inf preserved,
// -inf zeroed, and -0.0 allowed to remain -0.0 (it compares == 0.0).
#[test]
fn test_relu_assign_edge_cases() {
let mut data = vec![
-1.0f32,
-0.0,
0.0,
0.5,
f32::NAN,
f32::INFINITY,
f32::NEG_INFINITY,
];
relu_assign_f32(&mut data);
assert_eq!(data[0], 0.0);
// -0.0 may stay -0.0 (it is not < 0.0); either sign of zero is fine.
assert!(!data[1].is_sign_negative() || data[1] == 0.0);
assert_eq!(data[2], 0.0);
assert_eq!(data[3], 0.5);
assert!(
data[4].is_nan(),
"NaN should pass through relu, got {}",
data[4]
);
assert_eq!(data[5], f32::INFINITY);
assert_eq!(data[6], 0.0);
}
// ReLU over larger buffers, compared against a scalar reference built
// before the in-place mutation.
#[test]
fn test_relu_assign_large() {
for &n in TEST_SIZES {
let mut data: Vec<f32> = (0..n).map(|i| (i as f32) - (n as f32 / 2.0)).collect();
let expected: Vec<f32> = data
.iter()
.map(|&x| if x >= 0.0 { x } else { 0.0 })
.collect();
relu_assign_f32(&mut data);
for (got, exp) in data.iter().zip(expected.iter()) {
assert_relative_eq!(got, exp, epsilon = 1e-7);
}
}
}
// Leaky ReLU: negatives scaled by the slope, non-negatives unchanged,
// NaN preserved, -inf stays -inf (finite slope * -inf = -inf).
#[test]
fn test_leaky_relu_assign() {
let slope = 0.01_f32;
let mut data = vec![-2.0f32, -1.0, 0.0, 1.0, f32::NAN, f32::NEG_INFINITY];
leaky_relu_assign_f32(&mut data, slope);
assert_relative_eq!(data[0], -0.02_f32, epsilon = 1e-7);
assert_relative_eq!(data[1], -0.01_f32, epsilon = 1e-7);
assert_eq!(data[2], 0.0);
assert_eq!(data[3], 1.0);
assert!(data[4].is_nan(), "NaN should pass through leaky relu");
assert!(data[5].is_infinite() && data[5] < 0.0);
}
// Clamp must leave NaN untouched rather than forcing it to a bound.
#[test]
fn test_clamp_nan_passthrough() {
let mut data = vec![f32::NAN, -2.0, 0.5, 2.0];
clamp_assign_f32(&mut data, -1.0, 1.0);
assert!(
data[0].is_nan(),
"NaN should pass through clamp, got {}",
data[0]
);
assert_eq!(data[1], -1.0, "clamped to min");
assert_eq!(data[2], 0.5, "unchanged");
assert_eq!(data[3], 1.0, "clamped to max");
}
// Clamp boundary values: infinities saturate to the bounds; values
// exactly at the bounds and in-range values are unchanged.
#[test]
fn test_clamp_edge_values() {
let mut data = vec![f32::NEG_INFINITY, f32::INFINITY, -1.0, 1.0, 0.0];
clamp_assign_f32(&mut data, -1.0, 1.0);
assert_eq!(data[0], -1.0);
assert_eq!(data[1], 1.0);
assert_eq!(data[2], -1.0);
assert_eq!(data[3], 1.0);
assert_eq!(data[4], 0.0);
}
// Each enum variant of dispatch_into routes to the matching kernel.
#[test]
fn test_binary_f32op_dispatch_into() {
let a = vec![2.0f32, 4.0, 6.0, 8.0];
let b = vec![1.0f32, 2.0, 3.0, 4.0];
let mut out = vec![0.0f32; 4];
BinaryF32Op::Add.dispatch_into(&a, &b, &mut out);
assert_eq!(out, vec![3.0, 6.0, 9.0, 12.0]);
BinaryF32Op::Sub.dispatch_into(&a, &b, &mut out);
assert_eq!(out, vec![1.0, 2.0, 3.0, 4.0]);
BinaryF32Op::Mul.dispatch_into(&a, &b, &mut out);
assert_eq!(out, vec![2.0, 8.0, 18.0, 32.0]);
BinaryF32Op::Div.dispatch_into(&a, &b, &mut out);
assert_eq!(out, vec![2.0, 2.0, 2.0, 2.0]);
}
// Each enum variant of dispatch_inplace routes to the matching kernel;
// `out` is re-cloned from `a` before each op to isolate them.
#[test]
fn test_binary_f32op_dispatch_inplace() {
let a = vec![2.0f32, 4.0, 6.0, 8.0];
let b = vec![1.0f32, 2.0, 3.0, 4.0];
let mut out = a.clone();
BinaryF32Op::Add.dispatch_inplace(&mut out, &b);
assert_eq!(out, vec![3.0, 6.0, 9.0, 12.0]);
let mut out = a.clone();
BinaryF32Op::Sub.dispatch_inplace(&mut out, &b);
assert_eq!(out, vec![1.0, 2.0, 3.0, 4.0]);
let mut out = a.clone();
BinaryF32Op::Mul.dispatch_inplace(&mut out, &b);
assert_eq!(out, vec![2.0, 8.0, 18.0, 32.0]);
let mut out = a.clone();
BinaryF32Op::Div.dispatch_inplace(&mut out, &b);
assert_eq!(out, vec![2.0, 2.0, 2.0, 2.0]);
}
// Dispatch on empty slices must be a no-op, not a panic.
#[test]
fn test_binary_f32op_empty_slices() {
let a: Vec<f32> = vec![];
let b: Vec<f32> = vec![];
let mut out: Vec<f32> = vec![];
BinaryF32Op::Add.dispatch_into(&a, &b, &mut out);
BinaryF32Op::Add.dispatch_inplace(&mut out, &b);
}
// Smoke-test that the Debug derive formats every variant.
#[test]
fn test_binary_f32op_debug() {
let _ = format!("{:?}", BinaryF32Op::Add);
let _ = format!("{:?}", BinaryF32Op::Sub);
let _ = format!("{:?}", BinaryF32Op::Mul);
let _ = format!("{:?}", BinaryF32Op::Div);
}
}