simdeez 3.0.1

SIMD library to abstract over different instruction sets and widths
Documentation
#![allow(dead_code, unused_parens)]

use crate::prelude::*;

fn assert_f32_min_max_contract(a: f32, b: f32, min: f32, max: f32) {
    if a.is_nan() || b.is_nan() {
        assert!(min.is_nan() || min == a || min == b);
        assert!(max.is_nan() || max == a || max == b);
        return;
    }

    if a == 0.0 && b == 0.0 {
        assert_eq!(min, 0.0);
        assert_eq!(max, 0.0);
        return;
    }

    assert_eq!(min.to_bits(), scalar_simd_min_f32(a, b).to_bits());
    assert_eq!(max.to_bits(), scalar_simd_max_f32(a, b).to_bits());
}

fn assert_f64_min_max_contract(a: f64, b: f64, min: f64, max: f64) {
    if a.is_nan() || b.is_nan() {
        assert!(min.is_nan() || min == a || min == b);
        assert!(max.is_nan() || max == a || max == b);
        return;
    }

    if a == 0.0 && b == 0.0 {
        assert_eq!(min, 0.0);
        assert_eq!(max, 0.0);
        return;
    }

    assert_eq!(min.to_bits(), scalar_simd_min_f64(a, b).to_bits());
    assert_eq!(max.to_bits(), scalar_simd_max_f64(a, b).to_bits());
}

#[inline]
fn scalar_simd_min_f32(a: f32, b: f32) -> f32 {
    if a < b {
        a
    } else {
        b
    }
}

#[inline]
fn scalar_simd_max_f32(a: f32, b: f32) -> f32 {
    if a > b {
        a
    } else {
        b
    }
}

#[inline]
fn scalar_simd_min_f64(a: f64, b: f64) -> f64 {
    if a < b {
        a
    } else {
        b
    }
}

#[inline]
fn scalar_simd_max_f64(a: f64, b: f64) -> f64 {
    if a > b {
        a
    } else {
        b
    }
}

simd_unsafe_generate_all!(
    fn min_max_f32(lhs: &[f32], rhs: &[f32], min_out: &mut [f32], max_out: &mut [f32]) {
        assert_eq!(lhs.len(), rhs.len());
        assert_eq!(lhs.len(), min_out.len());
        assert_eq!(lhs.len(), max_out.len());

        let mut lhs = lhs;
        let mut rhs = rhs;
        let mut min_out = min_out;
        let mut max_out = max_out;

        while lhs.len() >= S::Vf32::WIDTH {
            let a = S::Vf32::load_from_slice(lhs);
            let b = S::Vf32::load_from_slice(rhs);
            a.min(b).copy_to_slice(min_out);
            a.max(b).copy_to_slice(max_out);

            lhs = &lhs[S::Vf32::WIDTH..];
            rhs = &rhs[S::Vf32::WIDTH..];
            min_out = &mut min_out[S::Vf32::WIDTH..];
            max_out = &mut max_out[S::Vf32::WIDTH..];
        }

        for ((&a, &b), (min_slot, max_slot)) in lhs
            .iter()
            .zip(rhs.iter())
            .zip(min_out.iter_mut().zip(max_out.iter_mut()))
        {
            *min_slot = scalar_simd_min_f32(a, b);
            *max_slot = scalar_simd_max_f32(a, b);
        }
    }
);

simd_unsafe_generate_all!(
    fn min_max_f64(lhs: &[f64], rhs: &[f64], min_out: &mut [f64], max_out: &mut [f64]) {
        assert_eq!(lhs.len(), rhs.len());
        assert_eq!(lhs.len(), min_out.len());
        assert_eq!(lhs.len(), max_out.len());

        let mut lhs = lhs;
        let mut rhs = rhs;
        let mut min_out = min_out;
        let mut max_out = max_out;

        while lhs.len() >= S::Vf64::WIDTH {
            let a = S::Vf64::load_from_slice(lhs);
            let b = S::Vf64::load_from_slice(rhs);
            a.min(b).copy_to_slice(min_out);
            a.max(b).copy_to_slice(max_out);

            lhs = &lhs[S::Vf64::WIDTH..];
            rhs = &rhs[S::Vf64::WIDTH..];
            min_out = &mut min_out[S::Vf64::WIDTH..];
            max_out = &mut max_out[S::Vf64::WIDTH..];
        }

        for ((&a, &b), (min_slot, max_slot)) in lhs
            .iter()
            .zip(rhs.iter())
            .zip(min_out.iter_mut().zip(max_out.iter_mut()))
        {
            *min_slot = scalar_simd_min_f64(a, b);
            *max_slot = scalar_simd_max_f64(a, b);
        }
    }
);

simd_unsafe_generate_all!(
    fn rsqrt_f32(values: &[f32], out: &mut [f32]) {
        assert_eq!(values.len(), out.len());

        let mut values = values;
        let mut out = out;

        while values.len() >= S::Vf32::WIDTH {
            let a = S::Vf32::load_from_slice(values);
            a.rsqrt().copy_to_slice(out);

            values = &values[S::Vf32::WIDTH..];
            out = &mut out[S::Vf32::WIDTH..];
        }

        for (&value, slot) in values.iter().zip(out.iter_mut()) {
            *slot = 1.0 / value.sqrt();
        }
    }
);

simd_unsafe_generate_all!(
    fn rsqrt_f64(values: &[f64], out: &mut [f64]) {
        assert_eq!(values.len(), out.len());

        let mut values = values;
        let mut out = out;

        while values.len() >= S::Vf64::WIDTH {
            let a = S::Vf64::load_from_slice(values);
            a.rsqrt().copy_to_slice(out);

            values = &values[S::Vf64::WIDTH..];
            out = &mut out[S::Vf64::WIDTH..];
        }

        for (&value, slot) in values.iter().zip(out.iter_mut()) {
            *slot = 1.0 / value.sqrt();
        }
    }
);

#[test]
fn float_min_max_nan_and_signed_zero_contract_for_f32() {
    let lhs = vec![f32::NAN, 3.0, -0.0, 0.0, -5.0, 8.0, f32::NAN, -2.0];
    let rhs = vec![4.0, f32::NAN, 0.0, -0.0, -5.0, 3.0, f32::NAN, -3.0];

    let mut min = vec![0.0; lhs.len()];
    let mut max = vec![0.0; lhs.len()];

    min_max_f32(&lhs, &rhs, &mut min, &mut max);

    for (((&a, &b), &min), &max) in lhs.iter().zip(rhs.iter()).zip(min.iter()).zip(max.iter()) {
        assert_f32_min_max_contract(a, b, min, max);
    }
}

#[test]
fn float_min_max_nan_and_signed_zero_contract_for_f64() {
    let lhs = vec![f64::NAN, 3.0, -0.0, 0.0, -5.0, 8.0, f64::NAN, -2.0];
    let rhs = vec![4.0, f64::NAN, 0.0, -0.0, -5.0, 3.0, f64::NAN, -3.0];

    let mut min = vec![0.0; lhs.len()];
    let mut max = vec![0.0; lhs.len()];

    min_max_f64(&lhs, &rhs, &mut min, &mut max);

    for (((&a, &b), &min), &max) in lhs.iter().zip(rhs.iter()).zip(min.iter()).zip(max.iter()) {
        assert_f64_min_max_contract(a, b, min, max);
    }
}

#[test]
fn float_rsqrt_contract_is_for_finite_positive_inputs_f32() {
    let input = vec![0.03125_f32, 0.125, 0.5, 1.0, 2.0, 9.0, 64.0, 255.0];
    let expected = input.iter().map(|&x| 1.0f32 / x.sqrt()).collect::<Vec<_>>();

    let mut output = vec![0.0; input.len()];
    rsqrt_f32(&input, &mut output);

    for (&actual, &expected) in output.iter().zip(expected.iter()) {
        let rel = (actual - expected).abs() / expected.max(1.0);
        assert!(
            rel <= 0.02,
            "f32 rsqrt relative error too large: actual={actual}, expected={expected}, rel={rel}"
        );
    }
}

#[test]
fn float_rsqrt_contract_is_for_finite_positive_inputs_f64() {
    let input = vec![0.03125_f64, 0.125, 0.5, 1.0, 2.0, 9.0, 64.0, 255.0];
    let expected = input.iter().map(|&x| 1.0f64 / x.sqrt()).collect::<Vec<_>>();

    let mut output = vec![0.0; input.len()];
    rsqrt_f64(&input, &mut output);

    for (&actual, &expected) in output.iter().zip(expected.iter()) {
        let rel = (actual - expected).abs() / expected.max(1.0);
        assert!(
            rel <= 0.02,
            "f64 rsqrt relative error too large: actual={actual}, expected={expected}, rel={rel}"
        );
    }
}