ew_impl_wrap!(
f32,
x86_64_avx512_erf_f32_64n,
64,
16,
(),
#[inline(never)]
fn run(buf: &mut [f32], _: ()) {
debug_assert!(buf.len() % Self::nr() == 0);
debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
if buf.is_empty() {
return;
}
unsafe { x86_64_avx512_erf_f32_64n_run(buf) }
}
);
#[target_feature(enable = "avx512f")]
unsafe fn x86_64_avx512_erf_f32_64n_run(buf: &mut [f32]) {
unsafe {
let len = buf.len();
let ptr = buf.as_ptr();
const A1: f32 = 0.0705230784;
const A2: f32 = 0.0422820123;
const A3: f32 = 0.0092705272;
const A4: f32 = 0.0001520143;
const A5: f32 = 0.0002765672;
const A6: f32 = 0.0000430638;
const ABS_MASK: f32 = f32::from_bits(0x7fffffff);
const SIGN_MASK: f32 = f32::from_bits(0x80000000);
std::arch::asm!("
// broadcast constants (xmmN -> zmmN, broadcast across all 16 lanes)
vbroadcastss zmm0, xmm0 // a1
vbroadcastss zmm1, xmm1 // a2
vbroadcastss zmm2, xmm2 // a3
vbroadcastss zmm3, xmm3 // a4
vbroadcastss zmm4, xmm4 // a5
vbroadcastss zmm5, xmm5 // a6
vbroadcastss zmm6, xmm6 // 1.0
vbroadcastss zmm7, xmm7 // abs mask (0x7fffffff)
vbroadcastss zmm8, xmm8 // sign mask (0x80000000)
2:
// load 4 zmm of input
vmovaps zmm9, [{ptr}]
vmovaps zmm10, [{ptr} + 64]
vmovaps zmm11, [{ptr} + 128]
vmovaps zmm12, [{ptr} + 192]
// sign[i] = x[i] & SIGN_MASK (keeps only the sign bit)
vandps zmm13, zmm9, zmm8
vandps zmm14, zmm10, zmm8
vandps zmm15, zmm11, zmm8
vandps zmm16, zmm12, zmm8
// abs[i] = x[i] & ABS_MASK (clears the sign bit)
vandps zmm9, zmm9, zmm7
vandps zmm10, zmm10, zmm7
vandps zmm11, zmm11, zmm7
vandps zmm12, zmm12, zmm7
// y = a6 (in zmm17..20, 4 independent channels)
vmovaps zmm17, zmm5
vmovaps zmm18, zmm5
vmovaps zmm19, zmm5
vmovaps zmm20, zmm5
// y = y*abs + a5
vfmadd213ps zmm17, zmm9, zmm4
vfmadd213ps zmm18, zmm10, zmm4
vfmadd213ps zmm19, zmm11, zmm4
vfmadd213ps zmm20, zmm12, zmm4
// y = y*abs + a4
vfmadd213ps zmm17, zmm9, zmm3
vfmadd213ps zmm18, zmm10, zmm3
vfmadd213ps zmm19, zmm11, zmm3
vfmadd213ps zmm20, zmm12, zmm3
// y = y*abs + a3
vfmadd213ps zmm17, zmm9, zmm2
vfmadd213ps zmm18, zmm10, zmm2
vfmadd213ps zmm19, zmm11, zmm2
vfmadd213ps zmm20, zmm12, zmm2
// y = y*abs + a2
vfmadd213ps zmm17, zmm9, zmm1
vfmadd213ps zmm18, zmm10, zmm1
vfmadd213ps zmm19, zmm11, zmm1
vfmadd213ps zmm20, zmm12, zmm1
// y = y*abs + a1
vfmadd213ps zmm17, zmm9, zmm0
vfmadd213ps zmm18, zmm10, zmm0
vfmadd213ps zmm19, zmm11, zmm0
vfmadd213ps zmm20, zmm12, zmm0
// y = y * abs (final factor)
vmulps zmm17, zmm17, zmm9
vmulps zmm18, zmm18, zmm10
vmulps zmm19, zmm19, zmm11
vmulps zmm20, zmm20, zmm12
// y = y + 1
vaddps zmm17, zmm17, zmm6
vaddps zmm18, zmm18, zmm6
vaddps zmm19, zmm19, zmm6
vaddps zmm20, zmm20, zmm6
// y^16: square 4 times
vmulps zmm17, zmm17, zmm17
vmulps zmm18, zmm18, zmm18
vmulps zmm19, zmm19, zmm19
vmulps zmm20, zmm20, zmm20
vmulps zmm17, zmm17, zmm17
vmulps zmm18, zmm18, zmm18
vmulps zmm19, zmm19, zmm19
vmulps zmm20, zmm20, zmm20
vmulps zmm17, zmm17, zmm17
vmulps zmm18, zmm18, zmm18
vmulps zmm19, zmm19, zmm19
vmulps zmm20, zmm20, zmm20
vmulps zmm17, zmm17, zmm17
vmulps zmm18, zmm18, zmm18
vmulps zmm19, zmm19, zmm19
vmulps zmm20, zmm20, zmm20
// y = 1 / y (full-precision reciprocal, matches generic .recip())
vdivps zmm21, zmm6, zmm17
vdivps zmm22, zmm6, zmm18
vdivps zmm23, zmm6, zmm19
vdivps zmm24, zmm6, zmm20
// y = 1 - y
vsubps zmm21, zmm6, zmm21
vsubps zmm22, zmm6, zmm22
vsubps zmm23, zmm6, zmm23
vsubps zmm24, zmm6, zmm24
// copysign: stamp the original sign bit onto the (positive) result
vorps zmm21, zmm21, zmm13
vorps zmm22, zmm22, zmm14
vorps zmm23, zmm23, zmm15
vorps zmm24, zmm24, zmm16
// store
vmovaps [{ptr}], zmm21
vmovaps [{ptr} + 64], zmm22
vmovaps [{ptr} + 128], zmm23
vmovaps [{ptr} + 192], zmm24
add {ptr}, 256
sub {len}, 64
jnz 2b
",
len = inout(reg) len => _,
ptr = inout(reg) ptr => _,
inout("xmm0") A1 => _,
inout("xmm1") A2 => _,
inout("xmm2") A3 => _,
inout("xmm3") A4 => _,
inout("xmm4") A5 => _,
inout("xmm5") A6 => _,
inout("xmm6") 1f32 => _,
inout("xmm7") ABS_MASK => _,
inout("xmm8") SIGN_MASK => _,
out("zmm9") _, out("zmm10") _, out("zmm11") _, out("zmm12") _,
out("zmm13") _, out("zmm14") _, out("zmm15") _, out("zmm16") _,
out("zmm17") _, out("zmm18") _, out("zmm19") _, out("zmm20") _,
out("zmm21") _, out("zmm22") _, out("zmm23") _, out("zmm24") _,
);
}
}
#[cfg(test)]
pub mod test_x86_64_avx512_erf_f32_64n {
use super::*;
crate::erf_frame_tests!(is_x86_feature_detected!("avx512f"), f32, x86_64_avx512_erf_f32_64n);
}