simdeez 3.0.1

SIMD library to abstract over different instruction sets and widths
Documentation
use criterion::{Criterion, Throughput};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use simdeez::prelude::*;
use simdeez::scalar::Scalar;
use std::hint::black_box;

pub const INPUT_LEN: usize = 1 << 20;

pub fn make_positive_log_inputs(len: usize, seed: u64) -> Vec<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len)
        .map(|_| {
            let log2x = rng.gen_range(-20.0f32..20.0f32);
            let mantissa = rng.gen_range(1.0f32..2.0f32);
            mantissa * log2x.exp2()
        })
        .collect()
}

pub fn make_positive_log_inputs_f64(len: usize, seed: u64) -> Vec<f64> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len)
        .map(|_| {
            let log2x = rng.gen_range(-40.0f64..40.0f64);
            let mantissa = rng.gen_range(1.0f64..2.0f64);
            mantissa * log2x.exp2()
        })
        .collect()
}

pub fn make_exp2_inputs(len: usize, seed: u64) -> Vec<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len)
        .map(|_| rng.gen_range(-100.0f32..100.0f32))
        .collect()
}

pub fn make_exp2_inputs_f64(len: usize, seed: u64) -> Vec<f64> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len)
        .map(|_| rng.gen_range(-1000.0f64..1000.0f64))
        .collect()
}

pub fn make_exp_inputs(len: usize, seed: u64) -> Vec<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len).map(|_| rng.gen_range(-80.0f32..80.0f32)).collect()
}

pub fn make_exp_inputs_f64(len: usize, seed: u64) -> Vec<f64> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len)
        .map(|_| rng.gen_range(-700.0f64..700.0f64))
        .collect()
}

pub fn make_unary_inputs(len: usize, seed: u64, range: core::ops::Range<f32>) -> Vec<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len).map(|_| rng.gen_range(range.clone())).collect()
}

pub fn make_trig_inputs(len: usize, seed: u64) -> Vec<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len)
        .map(|_| rng.gen_range(-20.0f32 * core::f32::consts::PI..20.0f32 * core::f32::consts::PI))
        .collect()
}

pub fn make_trig_inputs_f64(len: usize, seed: u64) -> Vec<f64> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len)
        .map(|_| rng.gen_range(-100.0f64 * core::f64::consts::PI..100.0f64 * core::f64::consts::PI))
        .collect()
}

pub fn make_inverse_trig_inputs(len: usize, seed: u64) -> Vec<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len).map(|_| rng.gen_range(-1.0f32..1.0f32)).collect()
}

pub fn make_inverse_trig_inputs_f64(len: usize, seed: u64) -> Vec<f64> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len).map(|_| rng.gen_range(-1.0f64..1.0f64)).collect()
}

pub fn make_atan_inputs(len: usize, seed: u64) -> Vec<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len).map(|_| rng.gen_range(-64.0f32..64.0f32)).collect()
}

pub fn make_atan_inputs_f64(len: usize, seed: u64) -> Vec<f64> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len).map(|_| rng.gen_range(-64.0f64..64.0f64)).collect()
}

pub fn make_tan_inputs(len: usize, seed: u64) -> Vec<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let half_pi = core::f32::consts::FRAC_PI_2;
    (0..len)
        .map(|_| {
            let mut x =
                rng.gen_range(-20.0f32 * core::f32::consts::PI..20.0f32 * core::f32::consts::PI);
            let k = (x / core::f32::consts::PI).round();
            let nearest_pole = (k + 0.5) * core::f32::consts::PI;
            if (x - nearest_pole).abs() < 1.0e-3 {
                x += if x >= 0.0 { 2.5e-3 } else { -2.5e-3 };
            }
            if x == half_pi || x == -half_pi {
                x += 2.5e-3;
            }
            x
        })
        .collect()
}

pub fn make_tan_inputs_f64(len: usize, seed: u64) -> Vec<f64> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len)
        .map(|_| {
            let mut x =
                rng.gen_range(-100.0f64 * core::f64::consts::PI..100.0f64 * core::f64::consts::PI);
            let k = (x / core::f64::consts::PI).round();
            let nearest_pole = (k + 0.5) * core::f64::consts::PI;
            if (x - nearest_pole).abs() < 1.0e-8 {
                x += if x >= 0.0 { 2.5e-8 } else { -2.5e-8 };
            }
            x
        })
        .collect()
}

pub struct BenchTargets {
    pub scalar_native: fn(&[f32]) -> f32,
    pub simdeez_runtime: fn(&[f32]) -> f32,
    pub simdeez_scalar: fn(&[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse2: unsafe fn(&[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse41: unsafe fn(&[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx2: unsafe fn(&[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx512: unsafe fn(&[f32]) -> f32,
}

pub struct BenchTargetsF64 {
    pub scalar_native: fn(&[f64]) -> f64,
    pub simdeez_runtime: fn(&[f64]) -> f64,
    pub simdeez_scalar: fn(&[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse2: unsafe fn(&[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse41: unsafe fn(&[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx2: unsafe fn(&[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx512: unsafe fn(&[f64]) -> f64,
}

pub fn bench_variants(c: &mut Criterion, group_name: &str, input: &[f32], targets: BenchTargets) {
    let mut group = c.benchmark_group(group_name);
    group.throughput(Throughput::Elements(input.len() as u64));

    group.bench_function("scalar-native", |b| {
        b.iter(|| black_box((targets.scalar_native)(black_box(input))))
    });

    group.bench_function("simdeez-runtime", |b| {
        b.iter(|| black_box((targets.simdeez_runtime)(black_box(input))))
    });

    group.bench_function("simdeez-forced-scalar", |b| {
        b.iter(|| black_box((targets.simdeez_scalar)(black_box(input))))
    });

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("sse2") {
            group.bench_function("simdeez-forced-sse2", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_sse2)(black_box(input))) })
            });
        } else {
            eprintln!("[bench] skipped simdeez-forced-sse2 for {group_name}: CPU lacks sse2");
        }

        if std::is_x86_feature_detected!("sse4.1") {
            group.bench_function("simdeez-forced-sse41", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_sse41)(black_box(input))) })
            });
        } else {
            eprintln!("[bench] skipped simdeez-forced-sse41 for {group_name}: CPU lacks sse4.1");
        }

        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            group.bench_function("simdeez-forced-avx2", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_avx2)(black_box(input))) })
            });
        } else {
            eprintln!("[bench] skipped simdeez-forced-avx2 for {group_name}: CPU lacks avx2/fma");
        }

        if std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512bw")
            && std::is_x86_feature_detected!("avx512dq")
        {
            group.bench_function("simdeez-forced-avx512", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_avx512)(black_box(input))) })
            });
        } else {
            eprintln!(
                "[bench] skipped simdeez-forced-avx512 for {group_name}: CPU lacks avx512f+bw+dq"
            );
        }
    }

    group.finish();
}

pub fn bench_variants_f64(
    c: &mut Criterion,
    group_name: &str,
    input: &[f64],
    targets: BenchTargetsF64,
) {
    let mut group = c.benchmark_group(group_name);
    group.throughput(Throughput::Elements(input.len() as u64));

    group.bench_function("scalar-native", |b| {
        b.iter(|| black_box((targets.scalar_native)(black_box(input))))
    });

    group.bench_function("simdeez-runtime", |b| {
        b.iter(|| black_box((targets.simdeez_runtime)(black_box(input))))
    });

    group.bench_function("simdeez-forced-scalar", |b| {
        b.iter(|| black_box((targets.simdeez_scalar)(black_box(input))))
    });

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("sse2") {
            group.bench_function("simdeez-forced-sse2", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_sse2)(black_box(input))) })
            });
        } else {
            eprintln!("[bench] skipped simdeez-forced-sse2 for {group_name}: CPU lacks sse2");
        }

        if std::is_x86_feature_detected!("sse4.1") {
            group.bench_function("simdeez-forced-sse41", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_sse41)(black_box(input))) })
            });
        } else {
            eprintln!("[bench] skipped simdeez-forced-sse41 for {group_name}: CPU lacks sse4.1");
        }

        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            group.bench_function("simdeez-forced-avx2", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_avx2)(black_box(input))) })
            });
        } else {
            eprintln!("[bench] skipped simdeez-forced-avx2 for {group_name}: CPU lacks avx2/fma");
        }

        if std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512bw")
            && std::is_x86_feature_detected!("avx512dq")
        {
            group.bench_function("simdeez-forced-avx512", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_avx512)(black_box(input))) })
            });
        } else {
            eprintln!(
                "[bench] skipped simdeez-forced-avx512 for {group_name}: CPU lacks avx512f+bw+dq"
            );
        }
    }

    group.finish();
}

type ScalarVf32 = <Scalar as Simd>::Vf32;
type ScalarVf64 = <Scalar as Simd>::Vf64;

#[inline(never)]
pub fn force_scalar_sum(input: &[f32], op: impl Fn(ScalarVf32) -> ScalarVf32) -> f32 {
    simdeez_sum_impl::<Scalar>(input, op)
}

#[inline(never)]
pub fn force_scalar_sum_f64(input: &[f64], op: impl Fn(ScalarVf64) -> ScalarVf64) -> f64 {
    simdeez_sum_impl_f64::<Scalar>(input, op)
}

#[inline(always)]
pub fn simdeez_sum_impl<S: Simd>(input: &[f32], op: impl Fn(S::Vf32) -> S::Vf32) -> f32 {
    let mut sum = 0.0f32;
    let mut i = 0;

    while i + S::Vf32::WIDTH <= input.len() {
        let v = S::Vf32::load_from_slice(&input[i..]);
        sum += op(v).horizontal_add();
        i += S::Vf32::WIDTH;
    }

    sum
}

#[inline(always)]
pub fn simdeez_sum_impl_f64<S: Simd>(input: &[f64], op: impl Fn(S::Vf64) -> S::Vf64) -> f64 {
    let mut sum = 0.0f64;
    let mut i = 0;

    while i + S::Vf64::WIDTH <= input.len() {
        let v = S::Vf64::load_from_slice(&input[i..]);
        sum += op(v).horizontal_add();
        i += S::Vf64::WIDTH;
    }

    sum
}