use lazy_static::lazy_static;
use wide::{f32x4, f32x8};
lazy_static! {
/// Process-wide SIMD dispatcher. CPU feature detection runs exactly once,
/// lazily, on first dereference of `SIMD`.
pub static ref SIMD: SimdDispatcher = SimdDispatcher::detect();
}
/// Width-generic f32 SIMD operations.
///
/// Implemented for `f32x8`, `f32x4` (via the `wide` crate) and plain `f32`,
/// so DSP kernels can be written once against this trait and dispatched at
/// runtime to whatever vector width the CPU supports.
pub trait SimdLanes: Copy + Clone + Sized {
/// Number of f32 lanes packed in this vector type (8, 4, or 1).
const LANES: usize;
/// Broadcasts `val` into every lane.
fn splat(val: f32) -> Self;
/// Loads the first `LANES` values of `arr`. Panics if `arr.len() < LANES`.
fn from_array(arr: &[f32]) -> Self;
/// Lane-wise addition.
fn add(self, other: Self) -> Self;
/// Lane-wise subtraction.
fn sub(self, other: Self) -> Self;
/// Lane-wise multiplication.
fn mul(self, other: Self) -> Self;
/// Lane-wise division.
fn div(self, other: Self) -> Self;
/// Lane-wise absolute value.
fn abs(self) -> Self;
/// Lane-wise minimum.
fn min(self, other: Self) -> Self;
/// Lane-wise maximum.
fn max(self, other: Self) -> Self;
/// Lane-wise square root.
fn sqrt(self) -> Self;
/// Lane-wise fused multiply-add: `self * b + c`.
fn mul_add(self, b: Self, c: Self) -> Self;
/// Cheap tanh approximation, saturating to [-1, 1].
fn fast_tanh(self) -> Self;
/// Polynomial sin approximation; intended for arguments near [-pi/2, pi/2].
fn fast_sin(self) -> Self;
/// Polynomial cos approximation; intended for arguments near [-pi/2, pi/2].
fn fast_cos(self) -> Self;
/// Quadrant-aware atan2 approximation.
fn fast_atan2(y: Self, x: Self) -> Self;
/// Lane-wise clamp of `self` into `[min, max]`.
fn clamp(self, min: Self, max: Self) -> Self;
/// Stores all lanes into `slice`. Panics if `slice.len() < LANES`.
fn write_to_slice(self, slice: &mut [f32]);
/// Reads one lane by index. Panics if `lane >= LANES`.
fn extract_lane(self, lane: usize) -> f32;
}
/// Implements [`SimdLanes`] for a `wide` vector type with `$lanes` f32 lanes.
///
/// Most methods forward to the type's operators / inherent methods; the
/// `fast_*` approximations are evaluated lane-wise without branches.
macro_rules! impl_simd_lanes {
    ($type:ty, $lanes:expr) => {
        impl SimdLanes for $type {
            const LANES: usize = $lanes;

            #[inline(always)]
            fn splat(val: f32) -> Self {
                <$type>::splat(val)
            }

            /// Loads the first `$lanes` values of `arr`.
            /// Panics if `arr` is shorter (the `copy_from_slice` enforces it;
            /// the debug_assert just gives a clearer message).
            #[inline(always)]
            fn from_array(arr: &[f32]) -> Self {
                debug_assert!(arr.len() >= $lanes, "Array too short for SIMD width");
                let mut fixed = [0.0f32; $lanes];
                fixed.copy_from_slice(&arr[..$lanes]);
                <$type>::from(fixed)
            }

            #[inline(always)]
            fn add(self, other: Self) -> Self {
                self + other
            }

            #[inline(always)]
            fn sub(self, other: Self) -> Self {
                self - other
            }

            #[inline(always)]
            fn mul(self, other: Self) -> Self {
                self * other
            }

            #[inline(always)]
            fn div(self, other: Self) -> Self {
                self / other
            }

            #[inline(always)]
            fn abs(self) -> Self {
                self.abs()
            }

            #[inline(always)]
            fn min(self, other: Self) -> Self {
                self.min(other)
            }

            #[inline(always)]
            fn max(self, other: Self) -> Self {
                self.max(other)
            }

            #[inline(always)]
            fn sqrt(self) -> Self {
                self.sqrt()
            }

            #[inline(always)]
            fn mul_add(self, b: Self, c: Self) -> Self {
                self.mul_add(b, c)
            }

            /// Rational approximation tanh(x) ~= x(15 + x^2) / (15 + 6x^2),
            /// clamped to [-1, 1] so large inputs saturate like real tanh.
            #[inline(always)]
            fn fast_tanh(self) -> Self {
                let x2 = self.mul(self);
                let num = self.mul(x2.add(<$type>::splat(15.0)));
                let denom = x2.mul(<$type>::splat(6.0)).add(<$type>::splat(15.0));
                let result = num.div(denom);
                result.clamp(<$type>::splat(-1.0), <$type>::splat(1.0))
            }

            /// 7th-order Taylor series: x - x^3/6 + x^5/120 - x^7/5040.
            #[inline(always)]
            fn fast_sin(self) -> Self {
                let x = self;
                let x2 = x.mul(x);
                let x3 = x2.mul(x);
                let x5 = x3.mul(x2);
                let x7 = x5.mul(x2);
                let term1 = x;
                let term2 = x3.mul(<$type>::splat(-1.0 / 6.0));
                let term3 = x5.mul(<$type>::splat(1.0 / 120.0));
                let term4 = x7.mul(<$type>::splat(-1.0 / 5040.0));
                term1.add(term2).add(term3).add(term4)
            }

            /// 6th-order Taylor series: 1 - x^2/2 + x^4/24 - x^6/720.
            #[inline(always)]
            fn fast_cos(self) -> Self {
                let x2 = self.mul(self);
                let x4 = x2.mul(x2);
                let x6 = x4.mul(x2);
                let term1 = <$type>::splat(1.0);
                let term2 = x2.mul(<$type>::splat(-1.0 / 2.0));
                let term3 = x4.mul(<$type>::splat(1.0 / 24.0));
                let term4 = x6.mul(<$type>::splat(-1.0 / 720.0));
                term1.add(term2).add(term3).add(term4)
            }

            /// Lane-wise atan2 via an octant-folded minimax polynomial.
            ///
            /// Fixes over the previous version:
            /// - the polynomial now has the correct odd form
            ///   `a * (c0 + c1*s + ... + c5*s^5)` (the old code dropped the
            ///   leading factor `a` on every term past the first);
            /// - quadrant selection now uses real lane masks
            ///   (`cmp_gt`/`cmp_lt` + `blend`) instead of multiplying by
            ///   magnitude differences, which scaled the result arbitrarily.
            #[inline(always)]
            fn fast_atan2(y: Self, x: Self) -> Self {
                let abs_x = x.abs();
                let abs_y = y.abs();
                // Fold into the first octant: a in [0, 1]. The epsilon keeps
                // atan2(0, 0) finite (result 0).
                let a = abs_y.min(abs_x).div(abs_x.max(abs_y).max(<$type>::splat(1e-10)));
                let s = a.mul(a);
                // Horner evaluation of atan(a) ~= a*(c0 + c1*s + ... + c5*s^5).
                let mut r = <$type>::splat(-0.011_721_2);
                r = r.mul_add(s, <$type>::splat(0.05265332));
                r = r.mul_add(s, <$type>::splat(-0.11643287));
                r = r.mul_add(s, <$type>::splat(0.19354346));
                r = r.mul_add(s, <$type>::splat(-0.33262347));
                r = r.mul_add(s, <$type>::splat(0.99997726));
                let r = r.mul(a);
                let zero = <$type>::splat(0.0);
                let pi_2 = <$type>::splat(std::f32::consts::FRAC_PI_2);
                let pi = <$type>::splat(std::f32::consts::PI);
                // Undo the folding per lane: |y| > |x|  =>  r = pi/2 - r.
                let r = abs_y.cmp_gt(abs_x).blend(pi_2.sub(r), r);
                // x < 0  =>  r = pi - r.
                let r = x.cmp_lt(zero).blend(pi.sub(r), r);
                // y < 0  =>  r = -r.
                y.cmp_lt(zero).blend(zero.sub(r), r)
            }

            #[inline(always)]
            fn clamp(self, min: Self, max: Self) -> Self {
                self.max(min).min(max)
            }

            /// Stores all lanes; panics if `slice.len() < $lanes`.
            #[inline(always)]
            fn write_to_slice(self, slice: &mut [f32]) {
                let arr = self.to_array();
                slice[..$lanes].copy_from_slice(&arr);
            }

            #[inline(always)]
            fn extract_lane(self, lane: usize) -> f32 {
                self.to_array()[lane]
            }
        }
    };
}
impl_simd_lanes!(f32x8, 8);
impl_simd_lanes!(f32x4, 4);
/// Scalar (1-lane) fallback so generic kernels work without vector support.
/// Arithmetic forwards to `f32`'s inherent methods/operators.
impl SimdLanes for f32 {
    const LANES: usize = 1;

    #[inline(always)]
    fn splat(val: f32) -> Self {
        val
    }

    /// Reads the single lane. Panics if `arr` is empty.
    #[inline(always)]
    fn from_array(arr: &[f32]) -> Self {
        arr[0]
    }

    #[inline(always)]
    fn add(self, other: Self) -> Self {
        self + other
    }

    #[inline(always)]
    fn sub(self, other: Self) -> Self {
        self - other
    }

    #[inline(always)]
    fn mul(self, other: Self) -> Self {
        self * other
    }

    #[inline(always)]
    fn div(self, other: Self) -> Self {
        self / other
    }

    #[inline(always)]
    fn abs(self) -> Self {
        self.abs()
    }

    #[inline(always)]
    fn min(self, other: Self) -> Self {
        self.min(other)
    }

    #[inline(always)]
    fn max(self, other: Self) -> Self {
        self.max(other)
    }

    #[inline(always)]
    fn sqrt(self) -> Self {
        self.sqrt()
    }

    #[inline(always)]
    fn mul_add(self, b: Self, c: Self) -> Self {
        self.mul_add(b, c)
    }

    /// Rational approximation tanh(x) ~= x(15 + x^2) / (15 + 6x^2),
    /// clamped to [-1, 1] so large inputs saturate like real tanh.
    #[inline(always)]
    fn fast_tanh(self) -> Self {
        let x2 = self * self;
        let num = self * (x2 + 15.0);
        let denom = 15.0 + 6.0 * x2;
        let result = num / denom;
        result.clamp(-1.0, 1.0)
    }

    /// 7th-order Taylor series: x - x^3/6 + x^5/120 - x^7/5040.
    #[inline(always)]
    fn fast_sin(self) -> Self {
        let x = self;
        let x2 = x * x;
        let x3 = x2 * x;
        let x5 = x3 * x2;
        let x7 = x5 * x2;
        x - x3 / 6.0 + x5 / 120.0 - x7 / 5040.0
    }

    /// 6th-order Taylor series: 1 - x^2/2 + x^4/24 - x^6/720.
    #[inline(always)]
    fn fast_cos(self) -> Self {
        let x2 = self * self;
        let x4 = x2 * x2;
        let x6 = x4 * x2;
        1.0 - x2 / 2.0 + x4 / 24.0 - x6 / 720.0
    }

    /// atan2 via an octant-folded minimax polynomial.
    ///
    /// Fixes over the previous version:
    /// - the polynomial now has the correct odd form
    ///   `a * (c0 + c1*s + ... + c5*s^5)` (the old code dropped the factor
    ///   `a` on every term past the first);
    /// - quadrant III (x < 0, y < 0) now yields a negative angle; the old
    ///   code returned e.g. +3pi/4 for atan2(-1, -1) instead of -3pi/4.
    #[inline(always)]
    fn fast_atan2(y: Self, x: Self) -> Self {
        let abs_x = x.abs();
        let abs_y = y.abs();
        // Fold into the first octant: a in [0, 1]. The epsilon keeps
        // atan2(0, 0) finite (result 0).
        let a = abs_y.min(abs_x) / abs_x.max(abs_y).max(1e-10);
        let s = a * a;
        // Horner evaluation of atan(a) ~= a*(c0 + c1*s + ... + c5*s^5).
        let mut r = -0.011_721_2f32;
        r = r * s + 0.05265332;
        r = r * s - 0.11643287;
        r = r * s + 0.19354346;
        r = r * s - 0.33262347;
        r = r * s + 0.99997726;
        r *= a;
        // Undo the folding: |y| > |x|  =>  r = pi/2 - r.
        if abs_y > abs_x {
            r = std::f32::consts::FRAC_PI_2 - r;
        }
        // x < 0  =>  r = pi - r.
        if x < 0.0 {
            r = std::f32::consts::PI - r;
        }
        // y < 0  =>  r = -r.
        if y < 0.0 {
            -r
        } else {
            r
        }
    }

    #[inline(always)]
    fn clamp(self, min: Self, max: Self) -> Self {
        self.max(min).min(max)
    }

    /// Stores the single lane. Panics if `slice` is empty.
    #[inline(always)]
    fn write_to_slice(self, slice: &mut [f32]) {
        slice[0] = self;
    }

    #[inline(always)]
    fn extract_lane(self, _lane: usize) -> f32 {
        self
    }
}
/// Vector width selected at runtime by [`SimdDispatcher::detect`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdWidth {
/// 8 f32 lanes (AVX2 path on x86_64).
X8,
/// 4 f32 lanes (portable `wide::f32x4` path).
X4,
/// Plain scalar f32 fallback.
Scalar,
}
/// Routes buffer-processing kernels to the widest SIMD width the host CPU
/// supports; the width is chosen once (see `detect`) and reused for every call.
pub struct SimdDispatcher {
// Width chosen at construction time; never changes afterwards.
width: SimdWidth,
}
impl SimdDispatcher {
/// Picks the widest supported vector width for this CPU.
///
/// On x86_64: X8 when AVX2 is available at runtime, otherwise X4
/// (`wide::f32x4` is portable/polyfilled, so X4 is always safe).
/// On other architectures: always X4.
/// NOTE(review): the `Scalar` variant is never produced here — it is only
/// reachable if a dispatcher is built some other way; confirm intent.
pub fn detect() -> Self {
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
return Self {
width: SimdWidth::X8,
};
}
Self {
width: SimdWidth::X4,
}
}
#[cfg(not(target_arch = "x86_64"))]
{
Self {
width: SimdWidth::X4,
}
}
}
/// Selected width in lanes: 8, 4, or 1.
pub fn width(&self) -> usize {
match self.width {
SimdWidth::X8 => 8,
SimdWidth::X4 => 4,
SimdWidth::Scalar => 1,
}
}
/// Selected width as the enum value.
pub fn simd_width(&self) -> SimdWidth {
self.width
}
/// In-place `buffer[i] *= multiplier` for every sample.
#[inline]
pub fn multiply_const(&self, buffer: &mut [f32], multiplier: f32) {
match self.width {
SimdWidth::X8 => self.multiply_const_impl::<f32x8>(buffer, multiplier),
SimdWidth::X4 => self.multiply_const_impl::<f32x4>(buffer, multiplier),
SimdWidth::Scalar => self.multiply_const_impl::<f32>(buffer, multiplier),
}
}
// Generic body: vectorized full chunks, then a scalar tail.
#[inline(always)]
fn multiply_const_impl<V: SimdLanes>(&self, buffer: &mut [f32], multiplier: f32) {
let mult_vec = V::splat(multiplier);
// Split so the head length is a multiple of LANES; the tail is < LANES.
let (chunks, remainder) = buffer.split_at_mut(buffer.len() - (buffer.len() % V::LANES));
for chunk in chunks.chunks_exact_mut(V::LANES) {
let vec = V::from_array(chunk);
let result = vec.mul(mult_vec);
result.write_to_slice(chunk);
}
for sample in remainder.iter_mut() {
*sample *= multiplier;
}
}
/// In-place element-wise `buffer[i] *= modulation[i]`, over the shorter of
/// the two lengths; any excess in either slice is left untouched.
#[inline]
pub fn multiply_buffers(&self, buffer: &mut [f32], modulation: &[f32]) {
match self.width {
SimdWidth::X8 => self.multiply_buffers_impl::<f32x8>(buffer, modulation),
SimdWidth::X4 => self.multiply_buffers_impl::<f32x4>(buffer, modulation),
SimdWidth::Scalar => self.multiply_buffers_impl::<f32>(buffer, modulation),
}
}
#[inline(always)]
fn multiply_buffers_impl<V: SimdLanes>(&self, buffer: &mut [f32], modulation: &[f32]) {
// Clamp to the common length, then split both slices at the same
// LANES-aligned point so the remainders line up index-for-index.
let len = buffer.len().min(modulation.len());
let (buffer_chunks, buffer_rem) = buffer[..len].split_at_mut(len - (len % V::LANES));
let (mod_chunks, mod_rem) = modulation[..len].split_at(len - (len % V::LANES));
for (buf_chunk, mod_chunk) in buffer_chunks
.chunks_exact_mut(V::LANES)
.zip(mod_chunks.chunks_exact(V::LANES))
{
let buf_vec = V::from_array(buf_chunk);
let mod_vec = V::from_array(mod_chunk);
let result = buf_vec.mul(mod_vec);
result.write_to_slice(buf_chunk);
}
for i in 0..buffer_rem.len() {
buffer_rem[i] *= mod_rem[i];
}
}
/// In-place `buffer[i] = buffer[i] * mul + add` for every sample.
#[inline]
pub fn fma(&self, buffer: &mut [f32], mul: f32, add: f32) {
match self.width {
SimdWidth::X8 => self.fma_impl::<f32x8>(buffer, mul, add),
SimdWidth::X4 => self.fma_impl::<f32x4>(buffer, mul, add),
SimdWidth::Scalar => self.fma_impl::<f32>(buffer, mul, add),
}
}
#[inline(always)]
fn fma_impl<V: SimdLanes>(&self, buffer: &mut [f32], mul: f32, add: f32) {
let mul_vec = V::splat(mul);
let add_vec = V::splat(add);
let (chunks, remainder) = buffer.split_at_mut(buffer.len() - (buffer.len() % V::LANES));
for chunk in chunks.chunks_exact_mut(V::LANES) {
let vec = V::from_array(chunk);
let result = vec.mul_add(mul_vec, add_vec);
result.write_to_slice(chunk);
}
for sample in remainder.iter_mut() {
*sample = sample.mul_add(mul, add);
}
}
/// Applies the `fast_tanh` saturation curve to every sample in place.
#[inline]
pub fn apply_fast_tanh(&self, buffer: &mut [f32]) {
match self.width {
SimdWidth::X8 => self.apply_fast_tanh_impl::<f32x8>(buffer),
SimdWidth::X4 => self.apply_fast_tanh_impl::<f32x4>(buffer),
SimdWidth::Scalar => self.apply_fast_tanh_impl::<f32>(buffer),
}
}
#[inline(always)]
fn apply_fast_tanh_impl<V: SimdLanes>(&self, buffer: &mut [f32]) {
let (chunks, remainder) = buffer.split_at_mut(buffer.len() - (buffer.len() % V::LANES));
for chunk in chunks.chunks_exact_mut(V::LANES) {
let vec = V::from_array(chunk);
let result = vec.fast_tanh();
result.write_to_slice(chunk);
}
// Tail uses the scalar SimdLanes impl on f32.
for sample in remainder.iter_mut() {
*sample = sample.fast_tanh();
}
}
/// Element-wise linear interpolation `out[i] = a[i] + (b[i] - a[i]) * t[i]`,
/// over the shortest of the four slice lengths.
#[inline]
pub fn lerp_buffers(&self, out: &mut [f32], a: &[f32], b: &[f32], t: &[f32]) {
match self.width {
SimdWidth::X8 => self.lerp_impl::<f32x8>(out, a, b, t),
SimdWidth::X4 => self.lerp_impl::<f32x4>(out, a, b, t),
SimdWidth::Scalar => self.lerp_impl::<f32>(out, a, b, t),
}
}
#[inline(always)]
fn lerp_impl<V: SimdLanes>(&self, out: &mut [f32], a: &[f32], b: &[f32], t: &[f32]) {
// All four slices are split at the same LANES-aligned point so chunk
// iterators and remainders stay in lockstep.
let len = out.len().min(a.len()).min(b.len()).min(t.len());
let (out_chunks, out_rem) = out[..len].split_at_mut(len - (len % V::LANES));
let (a_chunks, a_rem) = a[..len].split_at(len - (len % V::LANES));
let (b_chunks, b_rem) = b[..len].split_at(len - (len % V::LANES));
let (t_chunks, t_rem) = t[..len].split_at(len - (len % V::LANES));
for (((out_chunk, a_chunk), b_chunk), t_chunk) in out_chunks
.chunks_exact_mut(V::LANES)
.zip(a_chunks.chunks_exact(V::LANES))
.zip(b_chunks.chunks_exact(V::LANES))
.zip(t_chunks.chunks_exact(V::LANES))
{
let va = V::from_array(a_chunk);
let vb = V::from_array(b_chunk);
let vt = V::from_array(t_chunk);
// a + (b - a) * t, with the multiply-add fused.
let diff = vb.sub(va);
let result = diff.mul_add(vt, va);
result.write_to_slice(out_chunk);
}
for i in 0..out_rem.len() {
out_rem[i] = a_rem[i] + (b_rem[i] - a_rem[i]) * t_rem[i];
}
}
/// Returns the sum of `x*x` over all samples (e.g. for RMS/energy metering).
#[inline]
pub fn sum_of_squares(&self, buffer: &[f32]) -> f32 {
match self.width {
SimdWidth::X8 => self.sum_of_squares_impl::<f32x8>(buffer),
SimdWidth::X4 => self.sum_of_squares_impl::<f32x4>(buffer),
SimdWidth::Scalar => self.sum_of_squares_impl::<f32>(buffer),
}
}
#[inline(always)]
fn sum_of_squares_impl<V: SimdLanes>(&self, buffer: &[f32]) -> f32 {
let lanes = V::LANES;
let chunks = buffer.len() / lanes;
let remainder_start = chunks * lanes;
// Accumulate per-lane partial sums, then reduce horizontally at the end.
let mut accumulator = V::splat(0.0);
for chunk_idx in 0..chunks {
let idx = chunk_idx * lanes;
let vec = V::from_array(&buffer[idx..idx + lanes]);
accumulator = vec.mul_add(vec, accumulator);
}
let mut sum = 0.0;
for i in 0..lanes {
sum += accumulator.extract_lane(i);
}
for &sample in &buffer[remainder_start..] {
sum += sample * sample;
}
sum
}
/// Mixes (adds) a mono `input` into interleaved stereo `output` with
/// independent per-channel gains: `L += in * left_gain; R += in * right_gain`.
/// Processes `min(input.len(), output.len() / 2)` frames.
#[inline]
pub fn mix_mono_to_stereo(
&self,
output: &mut [f32],
input: &[f32],
left_gain: f32,
right_gain: f32,
) {
match self.width {
SimdWidth::X8 => self.mix_mono_to_stereo_impl::<f32x8>(output, input, left_gain, right_gain),
SimdWidth::X4 => self.mix_mono_to_stereo_impl::<f32x4>(output, input, left_gain, right_gain),
SimdWidth::Scalar => self.mix_mono_to_stereo_impl::<f32>(output, input, left_gain, right_gain),
}
}
#[inline(always)]
fn mix_mono_to_stereo_impl<V: SimdLanes>(
&self,
output: &mut [f32],
input: &[f32],
left_gain: f32,
right_gain: f32,
) {
let num_frames = input.len().min(output.len() / 2);
let lanes = V::LANES;
let chunks = num_frames / lanes;
let remainder_start = chunks * lanes;
let left_gain_vec = V::splat(left_gain);
let right_gain_vec = V::splat(right_gain);
for chunk_idx in 0..chunks {
let mono_idx = chunk_idx * lanes;
let stereo_idx = chunk_idx * lanes * 2;
// Scratch arrays sized for the widest case (8 lanes); only the
// first `lanes` entries are used. Gather L/R, vector-mix, scatter back.
let mut out_left = [0.0f32; 8];
let mut out_right = [0.0f32; 8];
for i in 0..lanes {
out_left[i] = output[stereo_idx + i * 2];
out_right[i] = output[stereo_idx + i * 2 + 1];
}
let mono_vec = V::from_array(&input[mono_idx..mono_idx + lanes]);
let out_left_vec = V::from_array(&out_left[..lanes]);
let out_right_vec = V::from_array(&out_right[..lanes]);
let mixed_left = mono_vec.mul_add(left_gain_vec, out_left_vec);
let mixed_right = mono_vec.mul_add(right_gain_vec, out_right_vec);
mixed_left.write_to_slice(&mut out_left[..lanes]);
mixed_right.write_to_slice(&mut out_right[..lanes]);
for i in 0..lanes {
output[stereo_idx + i * 2] = out_left[i];
output[stereo_idx + i * 2 + 1] = out_right[i];
}
}
// Scalar tail for the frames that didn't fill a full vector.
for (frame_idx, &input_sample) in input.iter().enumerate().take(num_frames).skip(remainder_start) {
let stereo_idx = frame_idx * 2;
output[stereo_idx] += input_sample * left_gain;
output[stereo_idx + 1] += input_sample * right_gain;
}
}
/// Mixes (adds) interleaved stereo `input` into interleaved stereo `output`
/// with independent per-channel gains.
#[inline]
pub fn mix_stereo_interleaved(
&self,
output: &mut [f32],
input: &[f32],
left_gain: f32,
right_gain: f32,
) {
match self.width {
SimdWidth::X8 => {
self.mix_stereo_impl::<f32x8>(output, input, left_gain, right_gain)
}
SimdWidth::X4 => {
self.mix_stereo_impl::<f32x4>(output, input, left_gain, right_gain)
}
SimdWidth::Scalar => {
self.mix_stereo_impl::<f32>(output, input, left_gain, right_gain)
}
}
}
#[inline(always)]
fn mix_stereo_impl<V: SimdLanes>(
&self,
output: &mut [f32],
input: &[f32],
left_gain: f32,
right_gain: f32,
) {
// Frame count from the shorter interleaved buffer (2 samples per frame).
let num_frames = output.len().min(input.len()) / 2;
let lanes = V::LANES;
let chunks = num_frames / lanes;
let remainder_start = chunks * lanes;
let left_gain_vec = V::splat(left_gain);
let right_gain_vec = V::splat(right_gain);
for chunk_idx in 0..chunks {
let frame_start = chunk_idx * lanes;
let idx = frame_start * 2;
// Deinterleave into scratch (max 8 lanes), mix as vectors, re-interleave.
let mut input_left = [0.0f32; 8];
let mut input_right = [0.0f32; 8];
let mut output_left = [0.0f32; 8];
let mut output_right = [0.0f32; 8];
for i in 0..lanes {
input_left[i] = input[idx + i * 2];
input_right[i] = input[idx + i * 2 + 1];
output_left[i] = output[idx + i * 2];
output_right[i] = output[idx + i * 2 + 1];
}
let in_left = V::from_array(&input_left[..lanes]);
let in_right = V::from_array(&input_right[..lanes]);
let out_left = V::from_array(&output_left[..lanes]);
let out_right = V::from_array(&output_right[..lanes]);
let mixed_left = in_left.mul_add(left_gain_vec, out_left);
let mixed_right = in_right.mul_add(right_gain_vec, out_right);
mixed_left.write_to_slice(&mut output_left[..lanes]);
mixed_right.write_to_slice(&mut output_right[..lanes]);
for i in 0..lanes {
output[idx + i * 2] = output_left[i];
output[idx + i * 2 + 1] = output_right[i];
}
}
for frame_idx in remainder_start..num_frames {
let idx = frame_idx * 2;
output[idx] += input[idx] * left_gain;
output[idx + 1] += input[idx + 1] * right_gain;
}
}
/// Splits ONE vector's worth (LANES frames) of interleaved L/R samples into
/// `left`/`right`. Returns the number of frames written, or 0 when
/// `interleaved` holds fewer than LANES frames; callers are expected to loop.
/// NOTE(review): `left`/`right` are written via `[..lanes]` with no length
/// check — slices shorter than LANES will panic; confirm against call sites.
pub fn deinterleave_stereo(&self, interleaved: &[f32], left: &mut [f32], right: &mut [f32]) -> usize {
match self.width {
SimdWidth::X8 => self.deinterleave_stereo_impl::<f32x8>(interleaved, left, right),
SimdWidth::X4 => self.deinterleave_stereo_impl::<f32x4>(interleaved, left, right),
SimdWidth::Scalar => self.deinterleave_stereo_impl::<f32>(interleaved, left, right),
}
}
#[inline(always)]
fn deinterleave_stereo_impl<V: SimdLanes>(
&self,
interleaved: &[f32],
left: &mut [f32],
right: &mut [f32],
) -> usize {
let lanes = V::LANES;
let samples_needed = lanes * 2;
if interleaved.len() < samples_needed {
return 0;
}
// Scratch sized for the widest case (8 lanes); first `lanes` entries used.
let mut left_arr = [0.0f32; 8];
let mut right_arr = [0.0f32; 8];
for i in 0..lanes {
left_arr[i] = interleaved[i * 2];
right_arr[i] = interleaved[i * 2 + 1];
}
let left_vec = V::from_array(&left_arr[..lanes]);
let right_vec = V::from_array(&right_arr[..lanes]);
left_vec.write_to_slice(&mut left[..lanes]);
right_vec.write_to_slice(&mut right[..lanes]);
lanes
}
/// Interleaves ONE vector's worth (LANES frames) of `left`/`right` into
/// `output` as L,R pairs. Returns frames written, or 0 when any slice is
/// too short; callers are expected to loop.
pub fn interleave_stereo(&self, left: &[f32], right: &[f32], output: &mut [f32]) -> usize {
match self.width {
SimdWidth::X8 => self.interleave_stereo_impl::<f32x8>(left, right, output),
SimdWidth::X4 => self.interleave_stereo_impl::<f32x4>(left, right, output),
SimdWidth::Scalar => self.interleave_stereo_impl::<f32>(left, right, output),
}
}
#[inline(always)]
fn interleave_stereo_impl<V: SimdLanes>(
&self,
left: &[f32],
right: &[f32],
output: &mut [f32],
) -> usize {
let lanes = V::LANES;
let samples_needed = lanes * 2;
// Unlike deinterleave, all three slice lengths are validated here.
if left.len() < lanes || right.len() < lanes || output.len() < samples_needed {
return 0;
}
let left_vec = V::from_array(&left[..lanes]);
let right_vec = V::from_array(&right[..lanes]);
let mut left_arr = [0.0f32; 8];
let mut right_arr = [0.0f32; 8];
left_vec.write_to_slice(&mut left_arr[..lanes]);
right_vec.write_to_slice(&mut right_arr[..lanes]);
for i in 0..lanes {
output[i * 2] = left_arr[i];
output[i * 2 + 1] = right_arr[i];
}
lanes
}
}
impl Default for SimdDispatcher {
fn default() -> Self {
Self::detect()
}
}
#[cfg(test)]
mod tests {
use super::*;
// Sanity: detection picks one of the supported widths (1, 4, or 8).
#[test]
fn test_simd_detection() {
let simd = SimdDispatcher::detect();
let width = simd.width();
assert!(width >= 1);
assert!(width <= 8);
println!("Detected SIMD width: {}", width);
}
// Scalar fallback implements the same trait surface as the vector types.
#[test]
fn test_scalar_lanes() {
let a = f32::splat(2.0);
let b = f32::splat(3.0);
assert_eq!(a.add(b), 5.0);
assert_eq!(a.mul(b), 6.0);
}
#[test]
fn test_f32x4_lanes() {
let a = f32x4::splat(2.0);
let b = f32x4::splat(3.0);
let result = a.add(b);
let arr = result.to_array();
assert_eq!(arr, [5.0, 5.0, 5.0, 5.0]);
}
#[test]
fn test_f32x8_lanes() {
let a = f32x8::splat(2.0);
let b = f32x8::splat(3.0);
let result = a.mul(b);
let arr = result.to_array();
assert_eq!(arr, [6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0]);
}
#[test]
fn test_write_to_slice() {
let vec = f32x4::splat(42.0);
let mut buffer = vec![0.0; 4];
vec.write_to_slice(&mut buffer);
assert_eq!(buffer, vec![42.0, 42.0, 42.0, 42.0]);
}
#[test]
fn test_sqrt() {
let vec = f32x4::splat(16.0);
let result = vec.sqrt();
let arr = result.to_array();
assert_eq!(arr, [4.0, 4.0, 4.0, 4.0]);
}
#[test]
fn test_mul_add() {
let a = f32x4::splat(2.0);
let b = f32x4::splat(3.0);
let c = f32x4::splat(1.0);
// 2 * 3 + 1 = 7 in every lane.
let result = a.mul_add(b, c); let arr = result.to_array();
assert_eq!(arr, [7.0, 7.0, 7.0, 7.0]);
}
// The approximation must stay within 1e-2 of the libm tanh at x = 1.
#[test]
fn test_fast_tanh() {
let vec = f32x4::splat(1.0);
let result = vec.fast_tanh();
let arr = result.to_array();
for &val in &arr {
assert!((val - 1.0f32.tanh()).abs() < 0.01, "fast_tanh accuracy check");
}
}
#[test]
fn test_clamp() {
let vec = f32x4::from([0.5, 1.5, -0.5, 2.5]);
let result = vec.clamp(f32x4::splat(0.0), f32x4::splat(2.0));
let arr = result.to_array();
assert_eq!(arr, [0.5, 1.5, 0.0, 2.0]);
}
// NOTE(review): the following tests name the module as
// crate::synthesis::simd — this path must match where this file is mounted
// in the crate (redundant with `use super::*` above; verify).
#[test]
fn test_simd_multiply_const() {
use crate::synthesis::simd::SIMD;
let mut buffer = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
SIMD.multiply_const(&mut buffer, 2.0);
assert_eq!(buffer, vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]);
}
#[test]
fn test_simd_fma() {
use crate::synthesis::simd::SIMD;
let mut buffer = vec![1.0, 2.0, 3.0, 4.0];
SIMD.fma(&mut buffer, 2.0, 1.0);
assert_eq!(buffer, vec![3.0, 5.0, 7.0, 9.0]);
}
#[test]
fn test_simd_fast_tanh() {
use crate::synthesis::simd::SIMD;
let mut buffer = vec![0.0, 1.0, -1.0, 0.5];
SIMD.apply_fast_tanh(&mut buffer);
assert!((buffer[0] - 0.0f32.tanh()).abs() < 0.01);
assert!((buffer[1] - 1.0f32.tanh()).abs() < 0.01);
assert!((buffer[2] - (-1.0f32).tanh()).abs() < 0.01);
assert!((buffer[3] - 0.5f32.tanh()).abs() < 0.01);
}
// Length 3 is smaller than any vector width: exercises the scalar tail path.
#[test]
fn test_simd_non_aligned_buffer() {
use crate::synthesis::simd::SIMD;
let mut buffer = vec![1.0, 2.0, 3.0];
SIMD.multiply_const(&mut buffer, 2.0);
assert_eq!(buffer, vec![2.0, 4.0, 6.0]);
}
// Scalar and f32x4 fast_sin agree with libm within 1e-3 on [-pi/2, pi/2].
#[test]
fn test_fast_sin_accuracy() {
use std::f32::consts::PI;
assert!((f32::fast_sin(0.0) - 0.0).abs() < 0.001);
assert!((f32::fast_sin(PI / 2.0) - 1.0).abs() < 0.001);
assert!((f32::fast_sin(-PI / 2.0) - (-1.0)).abs() < 0.001);
assert!((f32::fast_sin(PI / 4.0) - (PI / 4.0).sin()).abs() < 0.001);
let test_vals = [0.0, PI / 2.0, -PI / 2.0, PI / 4.0];
let vec = f32x4::from_array(&test_vals);
let result = vec.fast_sin();
let result_arr = result.to_array();
assert!((result_arr[0] - 0.0).abs() < 0.001);
assert!((result_arr[1] - 1.0).abs() < 0.001);
assert!((result_arr[2] - (-1.0)).abs() < 0.001);
assert!((result_arr[3] - (PI / 4.0).sin()).abs() < 0.001);
}
// The truncated Taylor series degrades near x = pi, hence the loose 0.25
// tolerance there.
#[test]
fn test_fast_cos_accuracy() {
use std::f32::consts::PI;
assert!((f32::fast_cos(0.0) - 1.0).abs() < 0.001);
assert!((f32::fast_cos(PI) - (-1.0)).abs() < 0.25); assert!((f32::fast_cos(PI / 2.0) - 0.0).abs() < 0.001);
assert!((f32::fast_cos(PI / 4.0) - (PI / 4.0).cos()).abs() < 0.001);
let test_vals = [0.0, PI, PI / 2.0, PI / 4.0, -PI / 4.0, PI / 6.0, 0.5, 1.0];
let vec = f32x8::from_array(&test_vals);
let result = vec.fast_cos();
let result_arr = result.to_array();
assert!((result_arr[0] - 1.0).abs() < 0.001);
assert!((result_arr[1] - (-1.0)).abs() < 0.25);
assert!((result_arr[2] - 0.0).abs() < 0.001);
assert!((result_arr[3] - (PI / 4.0).cos()).abs() < 0.001);
}
// Sweep of representative angles inside the accurate range, plus looser
// checks at the +/-pi edges where the Taylor truncation error grows.
#[test]
fn test_fast_sincos_vs_stdlib() {
use std::f32::consts::PI;
let test_values = [
-PI / 2.0, -PI / 4.0, -PI / 6.0, 0.0, PI / 6.0, PI / 4.0, PI / 2.0,
];
for &x in &test_values {
let fast_sin = f32::fast_sin(x);
let std_sin = x.sin();
assert!(
(fast_sin - std_sin).abs() < 0.002,
"fast_sin({}) = {} vs sin({}) = {}, error = {}",
x,
fast_sin,
std_sin,
x,
(fast_sin - std_sin).abs()
);
let fast_cos = f32::fast_cos(x);
let std_cos = x.cos();
assert!(
(fast_cos - std_cos).abs() < 0.002,
"fast_cos({}) = {} vs cos({}) = {}, error = {}",
x,
fast_cos,
std_cos,
x,
(fast_cos - std_cos).abs()
);
}
assert!((f32::fast_sin(PI) - PI.sin()).abs() < 0.08);
assert!((f32::fast_sin(-PI) - (-PI).sin()).abs() < 0.08);
assert!((f32::fast_cos(PI) - PI.cos()).abs() < 0.25); }
}