#[cfg(target_arch = "x86_64")]
mod avx2;
#[cfg(target_arch = "x86_64")]
mod avx512;
#[cfg(target_arch = "aarch64")]
mod aarch64;
use super::{SimdLevel, detect_simd};
use crate::ops::BinaryOp;
const SIMD_THRESHOLD: usize = 32;
#[inline]
const fn is_simd_supported(op: BinaryOp) -> bool {
matches!(
op,
BinaryOp::Add
| BinaryOp::Sub
| BinaryOp::Mul
| BinaryOp::Div
| BinaryOp::Max
| BinaryOp::Min
)
}
#[inline]
pub unsafe fn scalar_f32(op: BinaryOp, a: *const f32, scalar: f32, out: *mut f32, len: usize) {
let level = detect_simd();
if len < SIMD_THRESHOLD || level == SimdLevel::Scalar || !is_simd_supported(op) {
scalar_scalar_f32(op, a, scalar, out, len);
return;
}
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => avx512::scalar_f32(op, a, scalar, out, len),
SimdLevel::Avx2Fma => avx2::scalar_f32(op, a, scalar, out, len),
_ => scalar_scalar_f32(op, a, scalar, out, len),
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::scalar_f32(op, a, scalar, out, len),
_ => scalar_scalar_f32(op, a, scalar, out, len),
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
scalar_scalar_f32(op, a, scalar, out, len);
}
#[inline]
pub unsafe fn scalar_f64(op: BinaryOp, a: *const f64, scalar: f64, out: *mut f64, len: usize) {
let level = detect_simd();
if len < SIMD_THRESHOLD || level == SimdLevel::Scalar || !is_simd_supported(op) {
scalar_scalar_f64(op, a, scalar, out, len);
return;
}
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => avx512::scalar_f64(op, a, scalar, out, len),
SimdLevel::Avx2Fma => avx2::scalar_f64(op, a, scalar, out, len),
_ => scalar_scalar_f64(op, a, scalar, out, len),
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => aarch64::neon::scalar_f64(op, a, scalar, out, len),
_ => scalar_scalar_f64(op, a, scalar, out, len),
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
scalar_scalar_f64(op, a, scalar, out, len);
}
#[inline]
pub unsafe fn scalar_scalar_f32(
op: BinaryOp,
a: *const f32,
scalar: f32,
out: *mut f32,
len: usize,
) {
match op {
BinaryOp::Add => {
for i in 0..len {
*out.add(i) = *a.add(i) + scalar;
}
}
BinaryOp::Sub => {
for i in 0..len {
*out.add(i) = *a.add(i) - scalar;
}
}
BinaryOp::Mul => {
for i in 0..len {
*out.add(i) = *a.add(i) * scalar;
}
}
BinaryOp::Div => {
for i in 0..len {
*out.add(i) = *a.add(i) / scalar;
}
}
BinaryOp::Max => {
for i in 0..len {
let v = *a.add(i);
*out.add(i) = if v > scalar { v } else { scalar };
}
}
BinaryOp::Min => {
for i in 0..len {
let v = *a.add(i);
*out.add(i) = if v < scalar { v } else { scalar };
}
}
BinaryOp::Pow => {
for i in 0..len {
*out.add(i) = (*a.add(i)).powf(scalar);
}
}
BinaryOp::Atan2 => {
for i in 0..len {
*out.add(i) = (*a.add(i)).atan2(scalar);
}
}
}
}
#[inline]
pub unsafe fn scalar_scalar_f64(
op: BinaryOp,
a: *const f64,
scalar: f64,
out: *mut f64,
len: usize,
) {
match op {
BinaryOp::Add => {
for i in 0..len {
*out.add(i) = *a.add(i) + scalar;
}
}
BinaryOp::Sub => {
for i in 0..len {
*out.add(i) = *a.add(i) - scalar;
}
}
BinaryOp::Mul => {
for i in 0..len {
*out.add(i) = *a.add(i) * scalar;
}
}
BinaryOp::Div => {
for i in 0..len {
*out.add(i) = *a.add(i) / scalar;
}
}
BinaryOp::Max => {
for i in 0..len {
let v = *a.add(i);
*out.add(i) = if v > scalar { v } else { scalar };
}
}
BinaryOp::Min => {
for i in 0..len {
let v = *a.add(i);
*out.add(i) = if v < scalar { v } else { scalar };
}
}
BinaryOp::Pow => {
for i in 0..len {
*out.add(i) = (*a.add(i)).powf(scalar);
}
}
BinaryOp::Atan2 => {
for i in 0..len {
*out.add(i) = (*a.add(i)).atan2(scalar);
}
}
}
}
#[inline]
pub unsafe fn rsub_scalar_f32(a: *const f32, scalar: f32, out: *mut f32, len: usize) {
let level = detect_simd();
if len < SIMD_THRESHOLD || level == SimdLevel::Scalar {
for i in 0..len {
*out.add(i) = scalar - *a.add(i);
}
return;
}
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => avx512::rsub_scalar_f32(a, scalar, out, len),
SimdLevel::Avx2Fma => avx2::rsub_scalar_f32(a, scalar, out, len),
_ => {
for i in 0..len {
*out.add(i) = scalar - *a.add(i);
}
}
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => {
aarch64::neon::rsub_scalar_f32(a, scalar, out, len)
}
_ => {
for i in 0..len {
*out.add(i) = scalar - *a.add(i);
}
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
for i in 0..len {
*out.add(i) = scalar - *a.add(i);
}
}
}
#[inline]
pub unsafe fn rsub_scalar_f64(a: *const f64, scalar: f64, out: *mut f64, len: usize) {
let level = detect_simd();
if len < SIMD_THRESHOLD || level == SimdLevel::Scalar {
for i in 0..len {
*out.add(i) = scalar - *a.add(i);
}
return;
}
#[cfg(target_arch = "x86_64")]
match level {
SimdLevel::Avx512 => avx512::rsub_scalar_f64(a, scalar, out, len),
SimdLevel::Avx2Fma => avx2::rsub_scalar_f64(a, scalar, out, len),
_ => {
for i in 0..len {
*out.add(i) = scalar - *a.add(i);
}
}
}
#[cfg(target_arch = "aarch64")]
match level {
SimdLevel::Neon | SimdLevel::NeonFp16 => {
aarch64::neon::rsub_scalar_f64(a, scalar, out, len)
}
_ => {
for i in 0..len {
*out.add(i) = scalar - *a.add(i);
}
}
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
{
for i in 0..len {
*out.add(i) = scalar - *a.add(i);
}
}
}
half_scalar_op!(scalar, scalar_f32, BinaryOp);
half_unary_scalar!(rsub_scalar, rsub_scalar_f32);
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_scalar_add_f32() {
let a: Vec<f32> = (0..100).map(|x| x as f32).collect();
let mut out = vec![0.0f32; 100];
unsafe { scalar_f32(BinaryOp::Add, a.as_ptr(), 10.0, out.as_mut_ptr(), 100) }
for i in 0..100 {
assert_eq!(out[i], a[i] + 10.0, "mismatch at index {}", i);
}
}
#[test]
fn test_scalar_mul_f32() {
let a: Vec<f32> = (0..100).map(|x| x as f32).collect();
let mut out = vec![0.0f32; 100];
unsafe { scalar_f32(BinaryOp::Mul, a.as_ptr(), 2.5, out.as_mut_ptr(), 100) }
for i in 0..100 {
assert!(
(out[i] - a[i] * 2.5).abs() < 1e-6,
"mismatch at index {}",
i
);
}
}
#[test]
fn test_scalar_max_f32() {
let a: Vec<f32> = (0..100).map(|x| (x as f32) - 50.0).collect();
let mut out = vec![0.0f32; 100];
unsafe { scalar_f32(BinaryOp::Max, a.as_ptr(), 0.0, out.as_mut_ptr(), 100) }
for i in 0..100 {
let expected = if a[i] > 0.0 { a[i] } else { 0.0 };
assert_eq!(out[i], expected, "mismatch at index {}", i);
}
}
#[test]
fn test_scalar_div_f64() {
let a: Vec<f64> = (1..101).map(|x| x as f64).collect();
let mut out = vec![0.0f64; 100];
unsafe { scalar_f64(BinaryOp::Div, a.as_ptr(), 2.0, out.as_mut_ptr(), 100) }
for i in 0..100 {
assert!(
(out[i] - a[i] / 2.0).abs() < 1e-10,
"mismatch at index {}",
i
);
}
}
}