use tract_data::internal::f16;
#[repr(C, align(64))]
struct AlignedScratch([f32; 256]);
impl AlignedScratch {
fn new() -> Self {
Self([0f32; 256])
}
}
const CHUNK: usize = 256;
#[target_feature(enable = "avx512f")]
unsafe fn cvt_f16_to_f32(src: &[f16], dst: &mut [f32]) {
use core::arch::x86_64::*;
let n = src.len();
debug_assert!(dst.len() >= n);
let chunks = n / 16;
unsafe {
for k in 0..chunks {
let m = _mm256_loadu_si256(src.as_ptr().add(k * 16) as *const __m256i);
let z = _mm512_cvtph_ps(m);
_mm512_storeu_ps(dst.as_mut_ptr().add(k * 16), z);
}
for k in (chunks * 16)..n {
*dst.get_unchecked_mut(k) = src.get_unchecked(k).to_f32();
}
}
}
#[target_feature(enable = "avx512f")]
unsafe fn cvt_f32_to_f16(src: &[f32], dst: &mut [f16]) {
use core::arch::x86_64::*;
let n = src.len();
debug_assert!(dst.len() >= n);
let chunks = n / 16;
unsafe {
for k in 0..chunks {
let z = _mm512_loadu_ps(src.as_ptr().add(k * 16));
let m = _mm512_cvtps_ph::<0>(z);
_mm256_storeu_si256(dst.as_mut_ptr().add(k * 16) as *mut __m256i, m);
}
for k in (chunks * 16)..n {
*dst.get_unchecked_mut(k) = f16::from_f32(*src.get_unchecked(k));
}
}
}
ew_impl_wrap!(
f16,
x86_64_avx512_hardswish_f16_64n,
64,
32,
(),
#[inline(never)]
fn run(buf: &mut [f16], _: ()) {
debug_assert!(buf.len() % Self::nr() == 0);
debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
if buf.is_empty() {
return;
}
let mut scratch = AlignedScratch::new();
let s = &mut scratch.0;
let mut i = 0;
while i < buf.len() {
let n = (CHUNK).min(buf.len() - i);
unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut s[..n]) };
super::act::x86_64_avx512_hardswish_f32_64n::run(&mut s[..n], ());
unsafe { cvt_f32_to_f16(&s[..n], &mut buf[i..i + n]) };
i += n;
}
}
);
#[cfg(test)]
pub mod test_x86_64_avx512_hardswish_f16_64n {
use super::*;
hardswish_frame_tests!(
is_x86_feature_detected!("avx512f"),
f16,
x86_64_avx512_hardswish_f16_64n
);
}
ew_impl_wrap!(
f16,
x86_64_avx512_leaky_relu_f16_64n,
64,
32,
f16,
#[inline(never)]
fn run(buf: &mut [f16], alpha: f16) {
debug_assert!(buf.len() % Self::nr() == 0);
debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
if buf.is_empty() {
return;
}
let alpha_f32 = alpha.to_f32();
let mut scratch = AlignedScratch::new();
let s = &mut scratch.0;
let mut i = 0;
while i < buf.len() {
let n = (CHUNK).min(buf.len() - i);
unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut s[..n]) };
super::act::x86_64_avx512_leaky_relu_f32_64n::run(&mut s[..n], alpha_f32);
unsafe { cvt_f32_to_f16(&s[..n], &mut buf[i..i + n]) };
i += n;
}
}
);
#[cfg(test)]
pub mod test_x86_64_avx512_leaky_relu_f16_64n {
use super::*;
leaky_relu_frame_tests!(
is_x86_feature_detected!("avx512f"),
f16,
x86_64_avx512_leaky_relu_f16_64n
);
}
ew_impl_wrap!(
f16,
x86_64_avx512_sigmoid_f16_16n,
16,
16,
(),
#[inline(never)]
fn run(buf: &mut [f16], _: ()) {
debug_assert!(buf.len() % Self::nr() == 0);
debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
if buf.is_empty() {
return;
}
let mut scratch = AlignedScratch::new();
let s = &mut scratch.0;
let mut i = 0;
while i < buf.len() {
let n = (CHUNK).min(buf.len() - i);
unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut s[..n]) };
super::avx512_sigmoid_f32::run(&mut s[..n], ());
unsafe { cvt_f32_to_f16(&s[..n], &mut buf[i..i + n]) };
i += n;
}
}
);
#[cfg(test)]
pub mod test_x86_64_avx512_sigmoid_f16_16n {
use super::*;
sigmoid_frame_tests!(is_x86_feature_detected!("avx512f"), f16, x86_64_avx512_sigmoid_f16_16n);
}
ew_impl_wrap!(
f16,
x86_64_avx512_tanh_f16_16n,
16,
16,
(),
#[inline(never)]
fn run(buf: &mut [f16], _: ()) {
debug_assert!(buf.len() % Self::nr() == 0);
debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
if buf.is_empty() {
return;
}
let mut scratch = AlignedScratch::new();
let s = &mut scratch.0;
let mut i = 0;
while i < buf.len() {
let n = (CHUNK).min(buf.len() - i);
unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut s[..n]) };
super::avx512_tanh_f32::run(&mut s[..n], ());
unsafe { cvt_f32_to_f16(&s[..n], &mut buf[i..i + n]) };
i += n;
}
}
);
#[cfg(test)]
pub mod test_x86_64_avx512_tanh_f16_16n {
use super::*;
tanh_frame_tests!(is_x86_feature_detected!("avx512f"), f16, x86_64_avx512_tanh_f16_16n);
}
ew_impl_wrap!(
f16,
x86_64_avx512_silu_f16_16n,
16,
16,
(),
#[inline(never)]
fn run(buf: &mut [f16], _: ()) {
debug_assert!(buf.len() % Self::nr() == 0);
debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
if buf.is_empty() {
return;
}
let mut work = AlignedScratch::new();
let mut save = AlignedScratch::new();
let w = &mut work.0;
let v = &mut save.0;
let mut i = 0;
while i < buf.len() {
let n = (CHUNK).min(buf.len() - i);
unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut w[..n]) };
v[..n].copy_from_slice(&w[..n]);
super::avx512_sigmoid_f32::run(&mut w[..n], ());
for j in 0..n {
w[j] *= v[j];
}
unsafe { cvt_f32_to_f16(&w[..n], &mut buf[i..i + n]) };
i += n;
}
}
);
#[cfg(test)]
pub mod test_x86_64_avx512_silu_f16_16n {
use super::*;
silu_frame_tests!(is_x86_feature_detected!("avx512f"), f16, x86_64_avx512_silu_f16_16n);
}
ew_impl_wrap!(
f16,
x86_64_avx512_gelu_f16_16n,
16,
16,
(),
#[inline(never)]
fn run(buf: &mut [f16], _: ()) {
debug_assert!(buf.len() % Self::nr() == 0);
debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
if buf.is_empty() {
return;
}
const SQRT_2_OVER_PI: f32 = 0.7978845608028654;
const COEF: f32 = 0.044715;
let mut work = AlignedScratch::new();
let mut save = AlignedScratch::new();
let w = &mut work.0;
let v = &mut save.0;
let mut i = 0;
while i < buf.len() {
let n = (CHUNK).min(buf.len() - i);
unsafe { cvt_f16_to_f32(&buf[i..i + n], &mut v[..n]) };
for j in 0..n {
let x = v[j];
w[j] = SQRT_2_OVER_PI * (x + COEF * x * x * x);
}
super::avx512_tanh_f32::run(&mut w[..n], ());
for j in 0..n {
w[j] = 0.5 * v[j] * (1.0 + w[j]);
}
unsafe { cvt_f32_to_f16(&w[..n], &mut buf[i..i + n]) };
i += n;
}
}
);
#[cfg(test)]
pub mod test_x86_64_avx512_gelu_f16_16n {
use super::*;
gelu_frame_tests!(is_x86_feature_detected!("avx512f"), f16, x86_64_avx512_gelu_f16_16n);
}