use tract_data::internal::f16;
const FP16_TARGETS: &str = "avx512f,avx512fp16,avx512bw";
ew_impl_wrap!(
f16,
x86_64_avx512fp16_hardswish_f16_128n,
128,
32,
(),
#[inline(never)]
fn run(buf: &mut [f16], _: ()) {
debug_assert!(buf.len() % Self::nr() == 0);
debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
if buf.is_empty() {
return;
}
unsafe { hardswish_f16_run(buf) }
}
);
#[target_feature(enable = "avx512f,avx512fp16,avx512bw")]
unsafe fn hardswish_f16_run(buf: &mut [f16]) {
let len = buf.len();
let ptr = buf.as_ptr() as *mut u8;
let three = f16::from_f32(3.0).to_bits();
let six = f16::from_f32(6.0).to_bits();
let recip6 = f16::from_f32(1.0 / 6.0).to_bits();
unsafe {
std::arch::asm!("
vpbroadcastw zmm0, eax // 3.0
vpbroadcastw zmm1, ecx // 6.0
vpbroadcastw zmm2, edx // 1/6
vpxord zmm3, zmm3, zmm3 // 0.0
2:
vmovdqa64 zmm4, [{ptr}]
vmovdqa64 zmm5, [{ptr} + 64]
vmovdqa64 zmm6, [{ptr} + 128]
vmovdqa64 zmm7, [{ptr} + 192]
vaddph zmm8, zmm4, zmm0
vaddph zmm9, zmm5, zmm0
vaddph zmm10, zmm6, zmm0
vaddph zmm11, zmm7, zmm0
vminph zmm8, zmm8, zmm1
vminph zmm9, zmm9, zmm1
vminph zmm10, zmm10, zmm1
vminph zmm11, zmm11, zmm1
vmaxph zmm8, zmm8, zmm3
vmaxph zmm9, zmm9, zmm3
vmaxph zmm10, zmm10, zmm3
vmaxph zmm11, zmm11, zmm3
vmulph zmm8, zmm8, zmm4
vmulph zmm9, zmm9, zmm5
vmulph zmm10, zmm10, zmm6
vmulph zmm11, zmm11, zmm7
vmulph zmm8, zmm8, zmm2
vmulph zmm9, zmm9, zmm2
vmulph zmm10, zmm10, zmm2
vmulph zmm11, zmm11, zmm2
vmovdqa64 [{ptr}], zmm8
vmovdqa64 [{ptr} + 64], zmm9
vmovdqa64 [{ptr} + 128], zmm10
vmovdqa64 [{ptr} + 192], zmm11
add {ptr}, 256
sub {len}, 128
jnz 2b
",
len = inout(reg) len => _,
ptr = inout(reg) ptr => _,
in("eax") three as u32,
in("ecx") six as u32,
in("edx") recip6 as u32,
out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _,
out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
);
}
}
ew_impl_wrap!(
f16,
x86_64_avx512fp16_leaky_relu_f16_128n,
128,
32,
f16,
#[inline(never)]
fn run(buf: &mut [f16], alpha: f16) {
debug_assert!(buf.len() % Self::nr() == 0);
debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
if buf.is_empty() {
return;
}
unsafe { leaky_relu_f16_run(buf, alpha) }
}
);
#[target_feature(enable = "avx512f,avx512fp16,avx512bw")]
unsafe fn leaky_relu_f16_run(buf: &mut [f16], alpha: f16) {
let len = buf.len();
let ptr = buf.as_ptr() as *mut u8;
let alpha_bits = alpha.to_bits();
unsafe {
std::arch::asm!("
vpbroadcastw zmm0, eax // alpha
2:
vmovdqa64 zmm4, [{ptr}]
vmovdqa64 zmm5, [{ptr} + 64]
vmovdqa64 zmm6, [{ptr} + 128]
vmovdqa64 zmm7, [{ptr} + 192]
vmulph zmm8, zmm4, zmm0
vmulph zmm9, zmm5, zmm0
vmulph zmm10, zmm6, zmm0
vmulph zmm11, zmm7, zmm0
vmaxph zmm8, zmm8, zmm4
vmaxph zmm9, zmm9, zmm5
vmaxph zmm10, zmm10, zmm6
vmaxph zmm11, zmm11, zmm7
vmovdqa64 [{ptr}], zmm8
vmovdqa64 [{ptr} + 64], zmm9
vmovdqa64 [{ptr} + 128], zmm10
vmovdqa64 [{ptr} + 192], zmm11
add {ptr}, 256
sub {len}, 128
jnz 2b
",
len = inout(reg) len => _,
ptr = inout(reg) ptr => _,
in("eax") alpha_bits as u32,
out("zmm0") _,
out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
);
}
}
#[cfg(test)]
pub mod test_x86_64_avx512fp16_hardswish {
use super::*;
crate::hardswish_frame_tests!(
is_x86_feature_detected!("avx512fp16"),
f16,
x86_64_avx512fp16_hardswish_f16_128n
);
}
#[cfg(test)]
pub mod test_x86_64_avx512fp16_leaky_relu {
use super::*;
crate::leaky_relu_frame_tests!(
is_x86_feature_detected!("avx512fp16"),
f16,
x86_64_avx512fp16_leaky_relu_f16_128n
);
}
#[allow(dead_code)]
const _UNUSED: &str = FP16_TARGETS;