use pulp::Arch;
use std::sync::atomic::{AtomicU8, Ordering};
/// Memoised result of the `FERRAY_FORCE_SCALAR` environment lookup.
///
/// Encoding used by `force_scalar` / `reset_force_scalar`:
/// `0` = not looked up yet, `1` = scalar mode not forced, `2` = scalar mode forced.
static FORCE_SCALAR_CACHED: AtomicU8 = AtomicU8::new(0);
/// Reports whether the SIMD kernels should be bypassed in favour of scalar loops.
///
/// The answer comes from the `FERRAY_FORCE_SCALAR` environment variable: only the
/// exact value `"1"` enables scalar mode. The variable is read whenever the cache
/// is empty, and the outcome is memoised in `FORCE_SCALAR_CACHED`
/// (`1` = not forced, `2` = forced) so that later calls only perform a relaxed
/// atomic load.
#[inline]
pub fn force_scalar() -> bool {
    match FORCE_SCALAR_CACHED.load(Ordering::Relaxed) {
        // `0` means the environment has not been consulted yet.
        0 => {
            let forced = matches!(
                std::env::var("FERRAY_FORCE_SCALAR").as_deref(),
                Ok("1")
            );
            let encoded = if forced { 2 } else { 1 };
            FORCE_SCALAR_CACHED.store(encoded, Ordering::Relaxed);
            forced
        }
        // Any non-zero value is a previously cached decision.
        cached => cached == 2,
    }
}
/// Clears the memoised `FERRAY_FORCE_SCALAR` decision.
///
/// After this call the cache holds `0`, so the next `force_scalar()` invocation
/// reads the environment variable again.
pub fn reset_force_scalar() {
    // `0` is the "not looked up yet" marker that `force_scalar` checks for.
    const NOT_CHECKED: u8 = 0;
    FORCE_SCALAR_CACHED.store(NOT_CHECKED, Ordering::SeqCst);
}
/// Applies `scalar_fn` to every element of `input`, writing results into `output`.
///
/// Both slices are expected to have the same length; this is only verified by a
/// debug assertion. Without it, elements are processed pairwise up to the shorter
/// slice and anything beyond that in `output` is left untouched.
#[inline]
pub fn dispatch_unary_f32(input: &[f32], output: &mut [f32], scalar_fn: fn(f32) -> f32) {
    debug_assert_eq!(input.len(), output.len());
    output
        .iter_mut()
        .zip(input)
        .for_each(|(slot, &x)| *slot = scalar_fn(x));
}
/// Applies `scalar_fn` to every element of `input`, writing results into `output`.
///
/// Both slices are expected to have the same length; this is only verified by a
/// debug assertion. Without it, elements are processed pairwise up to the shorter
/// slice and anything beyond that in `output` is left untouched.
#[inline]
pub fn dispatch_unary_f64(input: &[f64], output: &mut [f64], scalar_fn: fn(f64) -> f64) {
    debug_assert_eq!(input.len(), output.len());
    output
        .iter_mut()
        .zip(input)
        .for_each(|(slot, &x)| *slot = scalar_fn(x));
}
/// Combines `a` and `b` element-wise with `scalar_fn`, storing results in `output`.
///
/// All three slices are expected to share one length; this is only verified by
/// debug assertions. Without them, processing stops at the shortest slice and the
/// remainder of `output` is left untouched.
#[inline]
pub fn dispatch_binary_f32(
    a: &[f32],
    b: &[f32],
    output: &mut [f32],
    scalar_fn: fn(f32, f32) -> f32,
) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), output.len());
    let operands = a.iter().zip(b);
    output
        .iter_mut()
        .zip(operands)
        .for_each(|(slot, (&lhs, &rhs))| *slot = scalar_fn(lhs, rhs));
}
/// Combines `a` and `b` element-wise with `scalar_fn`, storing results in `output`.
///
/// All three slices are expected to share one length; this is only verified by
/// debug assertions. Without them, processing stops at the shortest slice and the
/// remainder of `output` is left untouched.
#[inline]
pub fn dispatch_binary_f64(
    a: &[f64],
    b: &[f64],
    output: &mut [f64],
    scalar_fn: fn(f64, f64) -> f64,
) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), output.len());
    let operands = a.iter().zip(b);
    output
        .iter_mut()
        .zip(operands)
        .for_each(|(slot, (&lhs, &rhs))| *slot = scalar_fn(lhs, rhs));
}
/// Applies an `f32` function to half-precision data.
///
/// Each input element is widened with `to_f32`, passed through `scalar_fn`, and
/// the result is converted back with `half::f16::from_f32` before being stored.
/// Slice lengths are only checked by a debug assertion; without it, processing
/// stops at the shorter slice.
#[cfg(feature = "f16")]
#[inline]
pub fn dispatch_unary_f16(
    input: &[half::f16],
    output: &mut [half::f16],
    scalar_fn: fn(f32) -> f32,
) {
    debug_assert_eq!(input.len(), output.len());
    output.iter_mut().zip(input).for_each(|(slot, &x)| {
        let widened = x.to_f32();
        *slot = half::f16::from_f32(scalar_fn(widened));
    });
}
/// Applies a two-argument `f32` function to half-precision data.
///
/// Each pair of elements from `a` and `b` is widened with `to_f32`, combined by
/// `scalar_fn`, and the result is converted back with `half::f16::from_f32`.
/// Slice lengths are only checked by debug assertions; without them, processing
/// stops at the shortest slice.
#[cfg(feature = "f16")]
#[inline]
pub fn dispatch_binary_f16(
    a: &[half::f16],
    b: &[half::f16],
    output: &mut [half::f16],
    scalar_fn: fn(f32, f32) -> f32,
) {
    debug_assert_eq!(a.len(), b.len());
    debug_assert_eq!(a.len(), output.len());
    let operands = a.iter().zip(b);
    output.iter_mut().zip(operands).for_each(|(slot, (&lhs, &rhs))| {
        let combined = scalar_fn(lhs.to_f32(), rhs.to_f32());
        *slot = half::f16::from_f32(combined);
    });
}
/// Writes the square root of every `input` element into `output`.
///
/// Runs `SqrtF64Op` through pulp's runtime dispatcher unless `force_scalar()`
/// is true, in which case a plain element-wise loop is used. Slice lengths are
/// only checked by a debug assertion.
#[inline]
pub fn simd_sqrt_f64(input: &[f64], output: &mut [f64]) {
    debug_assert_eq!(input.len(), output.len());
    if !force_scalar() {
        Arch::new().dispatch(SqrtF64Op { input, output });
        return;
    }
    // Scalar fallback requested through the environment.
    output
        .iter_mut()
        .zip(input)
        .for_each(|(slot, &x)| *slot = x.sqrt());
}
/// Writes the square root of every `input` element into `output`.
///
/// Runs `SqrtF32Op` through pulp's runtime dispatcher unless `force_scalar()`
/// is true, in which case a plain element-wise loop is used. Slice lengths are
/// only checked by a debug assertion.
#[inline]
pub fn simd_sqrt_f32(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());
    if !force_scalar() {
        Arch::new().dispatch(SqrtF32Op { input, output });
        return;
    }
    // Scalar fallback requested through the environment.
    output
        .iter_mut()
        .zip(input)
        .for_each(|(slot, &x)| *slot = x.sqrt());
}
/// Writes the absolute value of every `input` element into `output`.
///
/// Runs `AbsF64Op` through pulp's runtime dispatcher unless `force_scalar()`
/// is true, in which case a plain element-wise loop is used. Slice lengths are
/// only checked by a debug assertion.
#[inline]
pub fn simd_abs_f64(input: &[f64], output: &mut [f64]) {
    debug_assert_eq!(input.len(), output.len());
    if !force_scalar() {
        Arch::new().dispatch(AbsF64Op { input, output });
        return;
    }
    // Scalar fallback requested through the environment.
    output
        .iter_mut()
        .zip(input)
        .for_each(|(slot, &x)| *slot = x.abs());
}
/// Writes the negation of every `input` element into `output`.
///
/// Runs `NegF64Op` through pulp's runtime dispatcher unless `force_scalar()`
/// is true, in which case a plain element-wise loop is used. Slice lengths are
/// only checked by a debug assertion.
#[inline]
pub fn simd_neg_f64(input: &[f64], output: &mut [f64]) {
    debug_assert_eq!(input.len(), output.len());
    if !force_scalar() {
        Arch::new().dispatch(NegF64Op { input, output });
        return;
    }
    // Scalar fallback requested through the environment.
    output
        .iter_mut()
        .zip(input)
        .for_each(|(slot, &x)| *slot = -x);
}
/// Writes the square (`x * x`) of every `input` element into `output`.
///
/// Runs `SquareF64Op` through pulp's runtime dispatcher unless `force_scalar()`
/// is true, in which case a plain element-wise loop is used. Slice lengths are
/// only checked by a debug assertion.
#[inline]
pub fn simd_square_f64(input: &[f64], output: &mut [f64]) {
    debug_assert_eq!(input.len(), output.len());
    if !force_scalar() {
        Arch::new().dispatch(SquareF64Op { input, output });
        return;
    }
    // Scalar fallback requested through the environment.
    output
        .iter_mut()
        .zip(input)
        .for_each(|(slot, &x)| *slot = x * x);
}
/// Writes the reciprocal (`1.0 / x`) of every `input` element into `output`.
///
/// Runs `ReciprocalF64Op` through pulp's runtime dispatcher unless
/// `force_scalar()` is true, in which case a plain element-wise loop is used.
/// Slice lengths are only checked by a debug assertion.
#[inline]
pub fn simd_reciprocal_f64(input: &[f64], output: &mut [f64]) {
    debug_assert_eq!(input.len(), output.len());
    if !force_scalar() {
        Arch::new().dispatch(ReciprocalF64Op { input, output });
        return;
    }
    // Scalar fallback requested through the environment.
    output
        .iter_mut()
        .zip(input)
        .for_each(|(slot, &x)| *slot = 1.0 / x);
}
/// Work item for pulp's dispatcher: element-wise `f64` square root.
struct SqrtF64Op<'a> {
    /// Values to take the square root of.
    input: &'a [f64],
    /// Destination for the results; indexed with the same positions as `input`.
    output: &'a mut [f64],
}

impl pulp::WithSimd for SqrtF64Op<'_> {
    type Output = ();

    /// Processes the data in three phases: blocks of four vectors, then single
    /// vectors, then a scalar loop for whatever does not fill a whole vector.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
        let Self { input: src, output: dst } = self;
        let len = src.len();
        // Number of f64 values held by one vector of this SIMD backend.
        let lanes = size_of::<S::f64s>() / size_of::<f64>();
        let block = lanes * 4;
        let blocks_end = len - (len % block);
        let vectors_end = len - (len % lanes);

        // Phase 1: four vectors per iteration (all loads, then all square roots,
        // then all stores).
        for base in (0..blocks_end).step_by(block) {
            let q1 = base + lanes;
            let q2 = base + lanes * 2;
            let q3 = base + lanes * 3;
            let q4 = base + block;

            let a0 = simd.partial_load_f64s(&src[base..q1]);
            let a1 = simd.partial_load_f64s(&src[q1..q2]);
            let a2 = simd.partial_load_f64s(&src[q2..q3]);
            let a3 = simd.partial_load_f64s(&src[q3..q4]);

            let s0 = simd.sqrt_f64s(a0);
            let s1 = simd.sqrt_f64s(a1);
            let s2 = simd.sqrt_f64s(a2);
            let s3 = simd.sqrt_f64s(a3);

            simd.partial_store_f64s(&mut dst[base..q1], s0);
            simd.partial_store_f64s(&mut dst[q1..q2], s1);
            simd.partial_store_f64s(&mut dst[q2..q3], s2);
            simd.partial_store_f64s(&mut dst[q3..q4], s3);
        }

        // Phase 2: remaining whole vectors, one at a time.
        for base in (blocks_end..vectors_end).step_by(lanes) {
            let stop = base + lanes;
            let loaded = simd.partial_load_f64s(&src[base..stop]);
            let rooted = simd.sqrt_f64s(loaded);
            simd.partial_store_f64s(&mut dst[base..stop], rooted);
        }

        // Phase 3: scalar tail.
        for idx in vectors_end..len {
            dst[idx] = src[idx].sqrt();
        }
    }
}
/// Work item for pulp's dispatcher: element-wise `f32` square root.
struct SqrtF32Op<'a> {
    /// Values to take the square root of.
    input: &'a [f32],
    /// Destination for the results; indexed with the same positions as `input`.
    output: &'a mut [f32],
}

impl pulp::WithSimd for SqrtF32Op<'_> {
    type Output = ();

    /// Handles whole vectors first, then finishes the leftover elements with a
    /// scalar loop.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
        let Self { input: src, output: dst } = self;
        let len = src.len();
        // Number of f32 values held by one vector of this SIMD backend.
        let lanes = size_of::<S::f32s>() / size_of::<f32>();
        let vectors_end = len - (len % lanes);

        let mut base = 0;
        while base < vectors_end {
            let stop = base + lanes;
            let loaded = simd.partial_load_f32s(&src[base..stop]);
            let rooted = simd.sqrt_f32s(loaded);
            simd.partial_store_f32s(&mut dst[base..stop], rooted);
            base = stop;
        }

        // Scalar tail for the elements that do not fill a whole vector.
        for idx in vectors_end..len {
            dst[idx] = src[idx].sqrt();
        }
    }
}
/// Work item for pulp's dispatcher: element-wise `f64` absolute value.
struct AbsF64Op<'a> {
    /// Values to take the absolute value of.
    input: &'a [f64],
    /// Destination for the results; indexed with the same positions as `input`.
    output: &'a mut [f64],
}

impl pulp::WithSimd for AbsF64Op<'_> {
    type Output = ();

    /// Handles whole vectors first, then finishes the leftover elements with a
    /// scalar loop.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
        let Self { input: src, output: dst } = self;
        let len = src.len();
        // Number of f64 values held by one vector of this SIMD backend.
        let lanes = size_of::<S::f64s>() / size_of::<f64>();
        let vectors_end = len - (len % lanes);

        let mut base = 0;
        while base < vectors_end {
            let stop = base + lanes;
            let loaded = simd.partial_load_f64s(&src[base..stop]);
            let magnitude = simd.abs_f64s(loaded);
            simd.partial_store_f64s(&mut dst[base..stop], magnitude);
            base = stop;
        }

        // Scalar tail for the elements that do not fill a whole vector.
        for idx in vectors_end..len {
            dst[idx] = src[idx].abs();
        }
    }
}
/// Work item for pulp's dispatcher: element-wise `f64` negation.
struct NegF64Op<'a> {
    /// Values to negate.
    input: &'a [f64],
    /// Destination for the results; indexed with the same positions as `input`.
    output: &'a mut [f64],
}

impl pulp::WithSimd for NegF64Op<'_> {
    type Output = ();

    /// Handles whole vectors first, then finishes the leftover elements with a
    /// scalar loop.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
        let Self { input: src, output: dst } = self;
        let len = src.len();
        // Number of f64 values held by one vector of this SIMD backend.
        let lanes = size_of::<S::f64s>() / size_of::<f64>();
        let vectors_end = len - (len % lanes);

        let mut base = 0;
        while base < vectors_end {
            let stop = base + lanes;
            let loaded = simd.partial_load_f64s(&src[base..stop]);
            let negated = simd.neg_f64s(loaded);
            simd.partial_store_f64s(&mut dst[base..stop], negated);
            base = stop;
        }

        // Scalar tail for the elements that do not fill a whole vector.
        for idx in vectors_end..len {
            dst[idx] = -src[idx];
        }
    }
}
/// Work item for pulp's dispatcher: element-wise `f64` squaring (`x * x`).
struct SquareF64Op<'a> {
    /// Values to square.
    input: &'a [f64],
    /// Destination for the results; indexed with the same positions as `input`.
    output: &'a mut [f64],
}

impl pulp::WithSimd for SquareF64Op<'_> {
    type Output = ();

    /// Handles whole vectors first, then finishes the leftover elements with a
    /// scalar loop.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
        let Self { input: src, output: dst } = self;
        let len = src.len();
        // Number of f64 values held by one vector of this SIMD backend.
        let lanes = size_of::<S::f64s>() / size_of::<f64>();
        let vectors_end = len - (len % lanes);

        let mut base = 0;
        while base < vectors_end {
            let stop = base + lanes;
            let loaded = simd.partial_load_f64s(&src[base..stop]);
            // Squaring is a multiplication of the vector by itself.
            let squared = simd.mul_f64s(loaded, loaded);
            simd.partial_store_f64s(&mut dst[base..stop], squared);
            base = stop;
        }

        // Scalar tail for the elements that do not fill a whole vector.
        for idx in vectors_end..len {
            dst[idx] = src[idx] * src[idx];
        }
    }
}
/// Work item for pulp's dispatcher: element-wise `f64` reciprocal (`1.0 / x`).
struct ReciprocalF64Op<'a> {
    /// Values to invert.
    input: &'a [f64],
    /// Destination for the results; indexed with the same positions as `input`.
    output: &'a mut [f64],
}

impl pulp::WithSimd for ReciprocalF64Op<'_> {
    type Output = ();

    /// Handles whole vectors first, then finishes the leftover elements with a
    /// scalar loop.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
        let Self { input: src, output: dst } = self;
        let len = src.len();
        // Number of f64 values held by one vector of this SIMD backend.
        let lanes = size_of::<S::f64s>() / size_of::<f64>();
        let vectors_end = len - (len % lanes);
        // Numerator vector, built once and reused for every division.
        let ones = simd.splat_f64s(1.0);

        let mut base = 0;
        while base < vectors_end {
            let stop = base + lanes;
            let loaded = simd.partial_load_f64s(&src[base..stop]);
            let inverted = simd.div_f64s(ones, loaded);
            simd.partial_store_f64s(&mut dst[base..stop], inverted);
            base = stop;
        }

        // Scalar tail for the elements that do not fill a whole vector.
        for idx in vectors_end..len {
            dst[idx] = 1.0 / src[idx];
        }
    }
}
/// Work item for pulp's dispatcher: element-wise `exp_fast_f64`.
struct ExpFastF64Op<'a> {
    /// Exponent values.
    input: &'a [f64],
    /// Destination for the results; indexed with the same positions as `input`.
    output: &'a mut [f64],
}

impl pulp::WithSimd for ExpFastF64Op<'_> {
    type Output = ();

    /// Calls `crate::fast_exp::exp_fast_f64` on every element.
    ///
    /// The SIMD token is not used directly.
    /// NOTE(review): presumably the point of going through the dispatcher is to
    /// compile this loop with the detected target features enabled — confirm.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, _simd: S) -> Self::Output {
        let Self { input: src, output: dst } = self;
        for (idx, &x) in src.iter().enumerate() {
            dst[idx] = crate::fast_exp::exp_fast_f64(x);
        }
    }
}
/// Work item for pulp's dispatcher: element-wise `exp_fast_f32`.
struct ExpFastF32Op<'a> {
    /// Exponent values.
    input: &'a [f32],
    /// Destination for the results; indexed with the same positions as `input`.
    output: &'a mut [f32],
}

impl pulp::WithSimd for ExpFastF32Op<'_> {
    type Output = ();

    /// Calls `crate::fast_exp::exp_fast_f32` on every element.
    ///
    /// The SIMD token is not used directly.
    /// NOTE(review): presumably the point of going through the dispatcher is to
    /// compile this loop with the detected target features enabled — confirm.
    #[inline(always)]
    fn with_simd<S: pulp::Simd>(self, _simd: S) -> Self::Output {
        let Self { input: src, output: dst } = self;
        for (idx, &x) in src.iter().enumerate() {
            dst[idx] = crate::fast_exp::exp_fast_f32(x);
        }
    }
}
/// Writes `exp_fast_f64(x)` for every `input` element into `output`.
///
/// Runs `ExpFastF64Op` through pulp's runtime dispatcher unless `force_scalar()`
/// is true, in which case the same per-element function is called from a plain
/// loop. Slice lengths are only checked by a debug assertion; indexing `output`
/// panics if it is shorter than `input`.
#[inline]
pub fn dispatch_exp_fast_f64(input: &[f64], output: &mut [f64]) {
    debug_assert_eq!(input.len(), output.len());
    if !force_scalar() {
        Arch::new().dispatch(ExpFastF64Op { input, output });
        return;
    }
    // Scalar fallback requested through the environment.
    for (idx, &x) in input.iter().enumerate() {
        output[idx] = crate::fast_exp::exp_fast_f64(x);
    }
}
/// Writes `exp_fast_f32(x)` for every `input` element into `output`.
///
/// Runs `ExpFastF32Op` through pulp's runtime dispatcher unless `force_scalar()`
/// is true, in which case the same per-element function is called from a plain
/// loop. Slice lengths are only checked by a debug assertion; indexing `output`
/// panics if it is shorter than `input`.
#[inline]
pub fn dispatch_exp_fast_f32(input: &[f32], output: &mut [f32]) {
    debug_assert_eq!(input.len(), output.len());
    if !force_scalar() {
        Arch::new().dispatch(ExpFastF32Op { input, output });
        return;
    }
    // Scalar fallback requested through the environment.
    for (idx, &x) in input.iter().enumerate() {
        output[idx] = crate::fast_exp::exp_fast_f32(x);
    }
}
/// Unit tests for the scalar helpers, the environment override, and the SIMD kernels.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dispatch_unary_f32_works() {
        let input = [1.0f32, 4.0, 9.0, 16.0];
        let mut output = [0.0f32; 4];
        dispatch_unary_f32(&input, &mut output, f32::sqrt);
        assert_eq!(output, [1.0, 2.0, 3.0, 4.0]);
    }

    // Named `_works` like its siblings: this exercises the scalar helper, not a
    // SIMD kernel.
    #[test]
    fn dispatch_unary_f64_works() {
        let input = [1.0f64, 4.0, 9.0, 16.0];
        let mut output = [0.0f64; 4];
        dispatch_unary_f64(&input, &mut output, f64::sqrt);
        assert_eq!(output, [1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn dispatch_binary_f32_works() {
        let a = [1.0f32, 2.0, 3.0];
        let b = [4.0f32, 5.0, 6.0];
        let mut out = [0.0f32; 3];
        dispatch_binary_f32(&a, &b, &mut out, |x, y| x + y);
        assert_eq!(out, [5.0, 7.0, 9.0]);
    }

    #[test]
    fn dispatch_binary_f64_works() {
        let a = [1.0f64, 2.0, 3.0];
        let b = [4.0f64, 5.0, 6.0];
        let mut out = [0.0f64; 3];
        dispatch_binary_f64(&a, &b, &mut out, |x, y| x * y);
        assert_eq!(out, [4.0, 10.0, 18.0]);
    }

    #[test]
    fn force_scalar_env() {
        // Compare against the environment this process actually runs in. A
        // hard-coded `false` would make the suite fail when it is deliberately
        // run with FERRAY_FORCE_SCALAR=1 to exercise the scalar paths.
        let expected = std::env::var("FERRAY_FORCE_SCALAR")
            .ok()
            .is_some_and(|v| v == "1");
        assert_eq!(force_scalar(), expected);
    }

    #[test]
    fn simd_f64_kernels_match_scalar() {
        // 45 elements: not a multiple of common vector widths, so both the
        // vector loops and the scalar tail are expected to run.
        let input: Vec<f64> = (1..=45).map(f64::from).collect();
        let mut output = vec![0.0f64; input.len()];

        simd_sqrt_f64(&input, &mut output);
        for (o, i) in output.iter().zip(&input) {
            assert_eq!(*o, i.sqrt());
        }

        simd_square_f64(&input, &mut output);
        for (o, i) in output.iter().zip(&input) {
            assert_eq!(*o, i * i);
        }

        simd_reciprocal_f64(&input, &mut output);
        for (o, i) in output.iter().zip(&input) {
            assert_eq!(*o, 1.0 / i);
        }

        simd_neg_f64(&input, &mut output);
        for (o, i) in output.iter().zip(&input) {
            assert_eq!(*o, -i);
        }

        // Feed the negated values back in to check that the sign is dropped.
        let negated = output.clone();
        simd_abs_f64(&negated, &mut output);
        assert_eq!(output, input);
    }

    #[test]
    fn simd_sqrt_f32_matches_scalar() {
        let input: Vec<f32> = (1u8..=45).map(f32::from).collect();
        let mut output = vec![0.0f32; input.len()];
        simd_sqrt_f32(&input, &mut output);
        for (o, i) in output.iter().zip(&input) {
            assert_eq!(*o, i.sqrt());
        }
    }

    #[test]
    fn simd_kernels_accept_empty_input() {
        let mut out64: [f64; 0] = [];
        simd_sqrt_f64(&[], &mut out64);
        let mut out32: [f32; 0] = [];
        simd_sqrt_f32(&[], &mut out32);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn dispatch_unary_f16_works() {
        let input = [
            half::f16::from_f32(1.0),
            half::f16::from_f32(4.0),
            half::f16::from_f32(9.0),
            half::f16::from_f32(16.0),
        ];
        let mut output = [half::f16::ZERO; 4];
        super::dispatch_unary_f16(&input, &mut output, f32::sqrt);
        let expected = [1.0f32, 2.0, 3.0, 4.0];
        for (o, &e) in output.iter().zip(expected.iter()) {
            assert!((o.to_f32() - e).abs() < 0.01);
        }
    }

    #[cfg(feature = "f16")]
    #[test]
    fn dispatch_binary_f16_works() {
        let a = [
            half::f16::from_f32(1.0),
            half::f16::from_f32(2.0),
            half::f16::from_f32(3.0),
        ];
        let b = [
            half::f16::from_f32(4.0),
            half::f16::from_f32(5.0),
            half::f16::from_f32(6.0),
        ];
        let mut out = [half::f16::ZERO; 3];
        super::dispatch_binary_f16(&a, &b, &mut out, |x, y| x + y);
        let expected = [5.0f32, 7.0, 9.0];
        for (o, &e) in out.iter().zip(expected.iter()) {
            assert!((o.to_f32() - e).abs() < 0.01);
        }
    }
}