simdeez 3.0.1

SIMD library to abstract over different instruction sets and widths
Documentation
use criterion::{Criterion, Throughput};
use rand::{Rng, SeedableRng};
use rand_chacha::ChaCha8Rng;
use simdeez::prelude::*;
use simdeez::scalar::Scalar;
use std::hint::black_box;

pub const INPUT_LEN: usize = 1 << 20;

pub fn make_unary_inputs(len: usize, seed: u64, range: core::ops::Range<f32>) -> Vec<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len).map(|_| rng.gen_range(range.clone())).collect()
}

pub fn make_positive_inputs(len: usize, seed: u64, min: f32, max: f32) -> Vec<f32> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len).map(|_| rng.gen_range(min..max)).collect()
}

pub fn make_unary_inputs_f64(len: usize, seed: u64, range: core::ops::Range<f64>) -> Vec<f64> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len).map(|_| rng.gen_range(range.clone())).collect()
}

pub fn make_positive_inputs_f64(len: usize, seed: u64, min: f64, max: f64) -> Vec<f64> {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    (0..len).map(|_| rng.gen_range(min..max)).collect()
}

pub fn make_binary_inputs(
    len: usize,
    seed: u64,
    range_a: core::ops::Range<f32>,
    range_b: core::ops::Range<f32>,
) -> (Vec<f32>, Vec<f32>) {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let a = (0..len).map(|_| rng.gen_range(range_a.clone())).collect();
    let b = (0..len).map(|_| rng.gen_range(range_b.clone())).collect();
    (a, b)
}

pub fn make_binary_inputs_f64(
    len: usize,
    seed: u64,
    range_a: core::ops::Range<f64>,
    range_b: core::ops::Range<f64>,
) -> (Vec<f64>, Vec<f64>) {
    let mut rng = ChaCha8Rng::seed_from_u64(seed);
    let a = (0..len).map(|_| rng.gen_range(range_a.clone())).collect();
    let b = (0..len).map(|_| rng.gen_range(range_b.clone())).collect();
    (a, b)
}

#[inline(always)]
pub fn simdeez_unary_sum_impl<S: Simd>(input: &[f32], op: impl Fn(S::Vf32) -> S::Vf32) -> f32 {
    let mut sum = 0.0f32;
    let mut i = 0;
    while i + S::Vf32::WIDTH <= input.len() {
        let v = S::Vf32::load_from_slice(&input[i..]);
        sum += op(v).horizontal_add();
        i += S::Vf32::WIDTH;
    }
    sum
}

#[inline(always)]
pub fn simdeez_unary_sum_impl_f64<S: Simd>(input: &[f64], op: impl Fn(S::Vf64) -> S::Vf64) -> f64 {
    let mut sum = 0.0f64;
    let mut i = 0;
    while i + S::Vf64::WIDTH <= input.len() {
        let v = S::Vf64::load_from_slice(&input[i..]);
        sum += op(v).horizontal_add();
        i += S::Vf64::WIDTH;
    }
    sum
}

#[inline(always)]
pub fn simdeez_binary_sum_impl<S: Simd>(
    a: &[f32],
    b: &[f32],
    op: impl Fn(S::Vf32, S::Vf32) -> S::Vf32,
) -> f32 {
    let mut sum = 0.0f32;
    let mut i = 0;
    while i + S::Vf32::WIDTH <= a.len() {
        let va = S::Vf32::load_from_slice(&a[i..]);
        let vb = S::Vf32::load_from_slice(&b[i..]);
        sum += op(va, vb).horizontal_add();
        i += S::Vf32::WIDTH;
    }
    sum
}

#[inline(always)]
pub fn simdeez_binary_sum_impl_f64<S: Simd>(
    a: &[f64],
    b: &[f64],
    op: impl Fn(S::Vf64, S::Vf64) -> S::Vf64,
) -> f64 {
    let mut sum = 0.0f64;
    let mut i = 0;
    while i + S::Vf64::WIDTH <= a.len() {
        let va = S::Vf64::load_from_slice(&a[i..]);
        let vb = S::Vf64::load_from_slice(&b[i..]);
        sum += op(va, vb).horizontal_add();
        i += S::Vf64::WIDTH;
    }
    sum
}

type ScalarVf32 = <Scalar as Simd>::Vf32;
type ScalarVf64 = <Scalar as Simd>::Vf64;

pub struct BenchTargets {
    pub scalar_native: fn(&[f32]) -> f32,
    pub simdeez_runtime: fn(&[f32]) -> f32,
    pub simdeez_scalar: fn(&[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse2: unsafe fn(&[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse41: unsafe fn(&[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx2: unsafe fn(&[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx512: unsafe fn(&[f32]) -> f32,
}

pub struct BenchTargetsF64 {
    pub scalar_native: fn(&[f64]) -> f64,
    pub simdeez_runtime: fn(&[f64]) -> f64,
    pub simdeez_scalar: fn(&[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse2: unsafe fn(&[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse41: unsafe fn(&[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx2: unsafe fn(&[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx512: unsafe fn(&[f64]) -> f64,
}

pub struct BinaryBenchTargets {
    pub scalar_native: fn(&[f32], &[f32]) -> f32,
    pub simdeez_runtime: fn(&[f32], &[f32]) -> f32,
    pub simdeez_scalar: fn(&[f32], &[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse2: unsafe fn(&[f32], &[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse41: unsafe fn(&[f32], &[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx2: unsafe fn(&[f32], &[f32]) -> f32,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx512: unsafe fn(&[f32], &[f32]) -> f32,
}

pub struct BinaryBenchTargetsF64 {
    pub scalar_native: fn(&[f64], &[f64]) -> f64,
    pub simdeez_runtime: fn(&[f64], &[f64]) -> f64,
    pub simdeez_scalar: fn(&[f64], &[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse2: unsafe fn(&[f64], &[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_sse41: unsafe fn(&[f64], &[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx2: unsafe fn(&[f64], &[f64]) -> f64,
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    pub simdeez_avx512: unsafe fn(&[f64], &[f64]) -> f64,
}

#[inline(never)]
pub fn force_scalar_sum(input: &[f32], op: impl Fn(ScalarVf32) -> ScalarVf32) -> f32 {
    simdeez_unary_sum_impl::<Scalar>(input, op)
}

#[inline(never)]
pub fn force_scalar_sum_f64(input: &[f64], op: impl Fn(ScalarVf64) -> ScalarVf64) -> f64 {
    simdeez_unary_sum_impl_f64::<Scalar>(input, op)
}

#[inline(never)]
pub fn force_scalar_binary_sum(
    a: &[f32],
    b: &[f32],
    op: impl Fn(ScalarVf32, ScalarVf32) -> ScalarVf32,
) -> f32 {
    simdeez_binary_sum_impl::<Scalar>(a, b, op)
}

#[inline(never)]
pub fn force_scalar_binary_sum_f64(
    a: &[f64],
    b: &[f64],
    op: impl Fn(ScalarVf64, ScalarVf64) -> ScalarVf64,
) -> f64 {
    simdeez_binary_sum_impl_f64::<Scalar>(a, b, op)
}

pub fn bench_unary_variants(c: &mut Criterion, name: &str, input: &[f32], targets: BenchTargets) {
    let mut group = c.benchmark_group(name);
    group.throughput(Throughput::Elements(input.len() as u64));
    group.bench_function("scalar-native", |b| {
        b.iter(|| black_box((targets.scalar_native)(black_box(input))))
    });
    group.bench_function("simdeez-runtime", |b| {
        b.iter(|| black_box((targets.simdeez_runtime)(black_box(input))))
    });
    group.bench_function("simdeez-forced-scalar", |b| {
        b.iter(|| black_box((targets.simdeez_scalar)(black_box(input))))
    });

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("sse2") {
            group.bench_function("simdeez-forced-sse2", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_sse2)(black_box(input))) })
            });
        }
        if std::is_x86_feature_detected!("sse4.1") {
            group.bench_function("simdeez-forced-sse41", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_sse41)(black_box(input))) })
            });
        }
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            group.bench_function("simdeez-forced-avx2", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_avx2)(black_box(input))) })
            });
        }
        if std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512bw")
            && std::is_x86_feature_detected!("avx512dq")
        {
            group.bench_function("simdeez-forced-avx512", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_avx512)(black_box(input))) })
            });
        }
    }

    group.finish();
}

pub fn bench_unary_variants_f64(
    c: &mut Criterion,
    name: &str,
    input: &[f64],
    targets: BenchTargetsF64,
) {
    let mut group = c.benchmark_group(name);
    group.throughput(Throughput::Elements(input.len() as u64));
    group.bench_function("scalar-native", |b| {
        b.iter(|| black_box((targets.scalar_native)(black_box(input))))
    });
    group.bench_function("simdeez-runtime", |b| {
        b.iter(|| black_box((targets.simdeez_runtime)(black_box(input))))
    });
    group.bench_function("simdeez-forced-scalar", |b| {
        b.iter(|| black_box((targets.simdeez_scalar)(black_box(input))))
    });

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("sse2") {
            group.bench_function("simdeez-forced-sse2", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_sse2)(black_box(input))) })
            });
        }
        if std::is_x86_feature_detected!("sse4.1") {
            group.bench_function("simdeez-forced-sse41", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_sse41)(black_box(input))) })
            });
        }
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            group.bench_function("simdeez-forced-avx2", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_avx2)(black_box(input))) })
            });
        }
        if std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512bw")
            && std::is_x86_feature_detected!("avx512dq")
        {
            group.bench_function("simdeez-forced-avx512", |b| {
                b.iter(|| unsafe { black_box((targets.simdeez_avx512)(black_box(input))) })
            });
        }
    }

    group.finish();
}

pub fn bench_binary_variants(
    c: &mut Criterion,
    name: &str,
    a: &[f32],
    b: &[f32],
    targets: BinaryBenchTargets,
) {
    let mut group = c.benchmark_group(name);
    group.throughput(Throughput::Elements(a.len() as u64));
    group.bench_function("scalar-native", |ben| {
        ben.iter(|| black_box((targets.scalar_native)(black_box(a), black_box(b))))
    });
    group.bench_function("simdeez-runtime", |ben| {
        ben.iter(|| black_box((targets.simdeez_runtime)(black_box(a), black_box(b))))
    });
    group.bench_function("simdeez-forced-scalar", |ben| {
        ben.iter(|| black_box((targets.simdeez_scalar)(black_box(a), black_box(b))))
    });

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("sse2") {
            group.bench_function("simdeez-forced-sse2", |ben| {
                ben.iter(|| unsafe {
                    black_box((targets.simdeez_sse2)(black_box(a), black_box(b)))
                })
            });
        }
        if std::is_x86_feature_detected!("sse4.1") {
            group.bench_function("simdeez-forced-sse41", |ben| {
                ben.iter(|| unsafe {
                    black_box((targets.simdeez_sse41)(black_box(a), black_box(b)))
                })
            });
        }
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            group.bench_function("simdeez-forced-avx2", |ben| {
                ben.iter(|| unsafe {
                    black_box((targets.simdeez_avx2)(black_box(a), black_box(b)))
                })
            });
        }
        if std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512bw")
            && std::is_x86_feature_detected!("avx512dq")
        {
            group.bench_function("simdeez-forced-avx512", |ben| {
                ben.iter(|| unsafe {
                    black_box((targets.simdeez_avx512)(black_box(a), black_box(b)))
                })
            });
        }
    }

    group.finish();
}

pub fn bench_binary_variants_f64(
    c: &mut Criterion,
    name: &str,
    a: &[f64],
    b: &[f64],
    targets: BinaryBenchTargetsF64,
) {
    let mut group = c.benchmark_group(name);
    group.throughput(Throughput::Elements(a.len() as u64));
    group.bench_function("scalar-native", |ben| {
        ben.iter(|| black_box((targets.scalar_native)(black_box(a), black_box(b))))
    });
    group.bench_function("simdeez-runtime", |ben| {
        ben.iter(|| black_box((targets.simdeez_runtime)(black_box(a), black_box(b))))
    });
    group.bench_function("simdeez-forced-scalar", |ben| {
        ben.iter(|| black_box((targets.simdeez_scalar)(black_box(a), black_box(b))))
    });

    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    {
        if std::is_x86_feature_detected!("sse2") {
            group.bench_function("simdeez-forced-sse2", |ben| {
                ben.iter(|| unsafe {
                    black_box((targets.simdeez_sse2)(black_box(a), black_box(b)))
                })
            });
        }
        if std::is_x86_feature_detected!("sse4.1") {
            group.bench_function("simdeez-forced-sse41", |ben| {
                ben.iter(|| unsafe {
                    black_box((targets.simdeez_sse41)(black_box(a), black_box(b)))
                })
            });
        }
        if std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma") {
            group.bench_function("simdeez-forced-avx2", |ben| {
                ben.iter(|| unsafe {
                    black_box((targets.simdeez_avx2)(black_box(a), black_box(b)))
                })
            });
        }
        if std::is_x86_feature_detected!("avx512f")
            && std::is_x86_feature_detected!("avx512bw")
            && std::is_x86_feature_detected!("avx512dq")
        {
            group.bench_function("simdeez-forced-avx512", |ben| {
                ben.iter(|| unsafe {
                    black_box((targets.simdeez_avx512)(black_box(a), black_box(b)))
                })
            });
        }
    }

    group.finish();
}