use super::traits::{SimdComplex, SimdVector};
use core::arch::x86_64::*;
/// Runtime check for AVX-512 Foundation support on the executing CPU.
///
/// The wrapper types below issue AVX-512F instructions unconditionally, so
/// callers are expected to gate on this function first (as the tests in this
/// module do) before constructing or operating on the vectors.
#[allow(dead_code)]
#[inline]
pub fn has_avx512f() -> bool {
    is_x86_feature_detected!("avx512f")
}
/// Eight packed `f64` lanes in one 512-bit AVX-512 register.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Avx512F64(pub __m512d);
/// Sixteen packed `f32` lanes in one 512-bit AVX-512 register.
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct Avx512F32(pub __m512);
// SAFETY: both wrappers are plain-old-data around a SIMD register value with
// no interior mutability, pointers, or thread affinity, so moving or sharing
// them across threads is sound.
unsafe impl Send for Avx512F64 {}
unsafe impl Sync for Avx512F64 {}
unsafe impl Send for Avx512F32 {}
unsafe impl Sync for Avx512F32 {}
/// `SimdVector` over eight packed `f64` lanes.
///
/// NOTE(review): the safe methods here (`splat`, `add`, ...) execute AVX-512F
/// instructions unconditionally; callers appear expected to gate on
/// `has_avx512f()` first — confirm at the call sites.
impl SimdVector for Avx512F64 {
    type Scalar = f64;
    // 512 bits / 64-bit scalar = 8 lanes.
    const LANES: usize = 8;
    /// Broadcasts `value` into every lane.
    #[inline]
    fn splat(value: f64) -> Self {
        unsafe { Self(_mm512_set1_pd(value)) }
    }
    /// Loads eight doubles starting at `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of 8 `f64`s and 64-byte aligned.
    #[inline]
    unsafe fn load_aligned(ptr: *const f64) -> Self {
        unsafe { Self(_mm512_load_pd(ptr)) }
    }
    /// Loads eight doubles starting at `ptr`, with no alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of 8 `f64`s.
    #[inline]
    unsafe fn load_unaligned(ptr: *const f64) -> Self {
        unsafe { Self(_mm512_loadu_pd(ptr)) }
    }
    /// Stores all eight lanes to `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of 8 `f64`s and 64-byte aligned.
    #[inline]
    unsafe fn store_aligned(self, ptr: *mut f64) {
        unsafe { _mm512_store_pd(ptr, self.0) }
    }
    /// Stores all eight lanes to `ptr`, with no alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of 8 `f64`s.
    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut f64) {
        unsafe { _mm512_storeu_pd(ptr, self.0) }
    }
    /// Lane-wise addition.
    #[inline]
    fn add(self, other: Self) -> Self {
        unsafe { Self(_mm512_add_pd(self.0, other.0)) }
    }
    /// Lane-wise subtraction.
    #[inline]
    fn sub(self, other: Self) -> Self {
        unsafe { Self(_mm512_sub_pd(self.0, other.0)) }
    }
    /// Lane-wise multiplication.
    #[inline]
    fn mul(self, other: Self) -> Self {
        unsafe { Self(_mm512_mul_pd(self.0, other.0)) }
    }
    /// Lane-wise division.
    #[inline]
    fn div(self, other: Self) -> Self {
        unsafe { Self(_mm512_div_pd(self.0, other.0)) }
    }
}
#[allow(dead_code)]
impl Avx512F64 {
    /// Builds a vector with `a` in lane 0 through `h` in lane 7.
    /// `_mm512_set_pd` takes its arguments highest-lane-first, hence the
    /// reversed argument order.
    #[inline]
    pub fn new(a: f64, b: f64, c: f64, d: f64, e: f64, f: f64, g: f64, h: f64) -> Self {
        unsafe { Self(_mm512_set_pd(h, g, f, e, d, c, b, a)) }
    }
    /// Returns the scalar in lane `idx` (0..8) by spilling the register to a
    /// stack buffer — convenient for tests, not intended for hot paths.
    #[inline]
    pub fn extract(self, idx: usize) -> f64 {
        debug_assert!(idx < 8);
        let mut arr = [0.0_f64; 8];
        unsafe { self.store_unaligned(arr.as_mut_ptr()) };
        arr[idx]
    }
    /// Fused multiply-add: `self * a + b` with a single rounding step.
    #[inline]
    pub fn fmadd(self, a: Self, b: Self) -> Self {
        unsafe { Self(_mm512_fmadd_pd(self.0, a.0, b.0)) }
    }
    /// Fused multiply-subtract: `self * a - b`.
    #[inline]
    pub fn fmsub(self, a: Self, b: Self) -> Self {
        unsafe { Self(_mm512_fmsub_pd(self.0, a.0, b.0)) }
    }
    /// Fused negated multiply-add: `-(self * a) + b`.
    #[inline]
    pub fn fnmadd(self, a: Self, b: Self) -> Self {
        unsafe { Self(_mm512_fnmadd_pd(self.0, a.0, b.0)) }
    }
    /// Fused negated multiply-subtract: `-(self * a) - b`.
    #[inline]
    pub fn fnmsub(self, a: Self, b: Self) -> Self {
        unsafe { Self(_mm512_fnmsub_pd(self.0, a.0, b.0)) }
    }
    /// Negates every lane by XOR-flipping the sign bit (cheaper than
    /// `0.0 - x` and preserves NaN payloads).
    #[inline]
    pub fn negate(self) -> Self {
        unsafe {
            let sign_mask = _mm512_set1_pd(-0.0);
            Self(_mm512_xor_pd(self.0, sign_mask))
        }
    }
    /// Lanes 0..4 as a 256-bit AVX vector (zero-cost register cast).
    #[inline]
    pub fn low_256(self) -> super::avx::AvxF64 {
        unsafe { super::avx::AvxF64(_mm512_castpd512_pd256(self.0)) }
    }
    /// Lanes 4..8 as a 256-bit AVX vector.
    #[inline]
    pub fn high_256(self) -> super::avx::AvxF64 {
        // FIX: immediate operands of AVX-512 intrinsics are const generics in
        // current `core::arch`, not runtime arguments.
        unsafe { super::avx::AvxF64(_mm512_extractf64x4_pd::<1>(self.0)) }
    }
    /// Permutes elements within each 128-bit pair according to `MASK`
    /// (one selector bit per destination lane, as for `_mm512_permute_pd`).
    #[inline]
    pub fn shuffle_within_lanes<const MASK: i32>(self) -> Self {
        // FIX: pass the immediate as a const generic, matching the
        // stabilized `_mm512_permute_pd::<IMM8>` signature.
        unsafe { Self(_mm512_permute_pd::<MASK>(self.0)) }
    }
    /// Interleaves the low element of each 128-bit pair of `self` and `other`.
    #[inline]
    pub fn unpack_lo(self, other: Self) -> Self {
        unsafe { Self(_mm512_unpacklo_pd(self.0, other.0)) }
    }
    /// Interleaves the high element of each 128-bit pair of `self` and `other`.
    #[inline]
    pub fn unpack_hi(self, other: Self) -> Self {
        unsafe { Self(_mm512_unpackhi_pd(self.0, other.0)) }
    }
}
/// Complex arithmetic over four interleaved `[re, im]` `f64` pairs.
impl SimdComplex for Avx512F64 {
    /// Complex multiply per 128-bit pair:
    /// `(ar*br - ai*bi, ar*bi + ai*br)`, computed with one `fmaddsub`
    /// (subtract in even lanes, add in odd lanes).
    #[inline]
    fn cmul(self, other: Self) -> Self {
        unsafe {
            // FIX: permute immediates are const generics in current
            // `core::arch`, not runtime arguments.
            let a_re = _mm512_permute_pd::<0b0000_0000>(self.0); // (ar, ar)
            let a_im = _mm512_permute_pd::<0b1111_1111>(self.0); // (ai, ai)
            let b_flip = _mm512_permute_pd::<0b0101_0101>(other.0); // (bi, br)
            Self(_mm512_fmaddsub_pd(
                a_re,
                other.0,
                _mm512_mul_pd(a_im, b_flip),
            ))
        }
    }
    /// Multiplies `self` by the complex conjugate of `other`: the imaginary
    /// lanes of `other` are sign-flipped, then the `cmul` dataflow is reused.
    #[inline]
    fn cmul_conj(self, other: Self) -> Self {
        unsafe {
            // `-0.0` in the odd (imaginary) lanes; XOR flips only those
            // sign bits. `_mm512_set_pd` lists lanes highest-first.
            let sign_mask = _mm512_set_pd(-0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0);
            let other_conj = _mm512_xor_pd(other.0, sign_mask);
            let a_re = _mm512_permute_pd::<0b0000_0000>(self.0);
            let a_im = _mm512_permute_pd::<0b1111_1111>(self.0);
            let b_flip = _mm512_permute_pd::<0b0101_0101>(other_conj);
            Self(_mm512_fmaddsub_pd(
                a_re,
                other_conj,
                _mm512_mul_pd(a_im, b_flip),
            ))
        }
    }
}
/// `SimdVector` over sixteen packed `f32` lanes.
///
/// NOTE(review): like the `f64` impl, the safe methods issue AVX-512F
/// instructions unconditionally; callers should gate on `has_avx512f()`.
impl SimdVector for Avx512F32 {
    type Scalar = f32;
    // 512 bits / 32-bit scalar = 16 lanes.
    const LANES: usize = 16;
    /// Broadcasts `value` into every lane.
    #[inline]
    fn splat(value: f32) -> Self {
        unsafe { Self(_mm512_set1_ps(value)) }
    }
    /// Loads sixteen floats starting at `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of 16 `f32`s and 64-byte aligned.
    #[inline]
    unsafe fn load_aligned(ptr: *const f32) -> Self {
        unsafe { Self(_mm512_load_ps(ptr)) }
    }
    /// Loads sixteen floats starting at `ptr`, with no alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for reads of 16 `f32`s.
    #[inline]
    unsafe fn load_unaligned(ptr: *const f32) -> Self {
        unsafe { Self(_mm512_loadu_ps(ptr)) }
    }
    /// Stores all sixteen lanes to `ptr`.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of 16 `f32`s and 64-byte aligned.
    #[inline]
    unsafe fn store_aligned(self, ptr: *mut f32) {
        unsafe { _mm512_store_ps(ptr, self.0) }
    }
    /// Stores all sixteen lanes to `ptr`, with no alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for writes of 16 `f32`s.
    #[inline]
    unsafe fn store_unaligned(self, ptr: *mut f32) {
        unsafe { _mm512_storeu_ps(ptr, self.0) }
    }
    /// Lane-wise addition.
    #[inline]
    fn add(self, other: Self) -> Self {
        unsafe { Self(_mm512_add_ps(self.0, other.0)) }
    }
    /// Lane-wise subtraction.
    #[inline]
    fn sub(self, other: Self) -> Self {
        unsafe { Self(_mm512_sub_ps(self.0, other.0)) }
    }
    /// Lane-wise multiplication.
    #[inline]
    fn mul(self, other: Self) -> Self {
        unsafe { Self(_mm512_mul_ps(self.0, other.0)) }
    }
    /// Lane-wise division.
    #[inline]
    fn div(self, other: Self) -> Self {
        unsafe { Self(_mm512_div_ps(self.0, other.0)) }
    }
}
#[allow(dead_code)]
impl Avx512F32 {
    /// Returns the scalar in lane `idx` (0..16) by spilling to a stack
    /// buffer — added for parity with `Avx512F64::extract`; test/debug use.
    #[inline]
    pub fn extract(self, idx: usize) -> f32 {
        debug_assert!(idx < 16);
        let mut arr = [0.0_f32; 16];
        unsafe { self.store_unaligned(arr.as_mut_ptr()) };
        arr[idx]
    }
    /// Fused multiply-add: `self * a + b` with a single rounding step.
    #[inline]
    pub fn fmadd(self, a: Self, b: Self) -> Self {
        unsafe { Self(_mm512_fmadd_ps(self.0, a.0, b.0)) }
    }
    /// Fused multiply-subtract: `self * a - b`.
    #[inline]
    pub fn fmsub(self, a: Self, b: Self) -> Self {
        unsafe { Self(_mm512_fmsub_ps(self.0, a.0, b.0)) }
    }
    /// Fused negated multiply-add: `-(self * a) + b`.
    #[inline]
    pub fn fnmadd(self, a: Self, b: Self) -> Self {
        unsafe { Self(_mm512_fnmadd_ps(self.0, a.0, b.0)) }
    }
    /// Fused negated multiply-subtract: `-(self * a) - b`.
    #[inline]
    pub fn fnmsub(self, a: Self, b: Self) -> Self {
        unsafe { Self(_mm512_fnmsub_ps(self.0, a.0, b.0)) }
    }
    /// Negates every lane by XOR-flipping the sign bit.
    #[inline]
    pub fn negate(self) -> Self {
        unsafe {
            let sign_mask = _mm512_set1_ps(-0.0);
            Self(_mm512_xor_ps(self.0, sign_mask))
        }
    }
    /// Lanes 0..8 as a 256-bit AVX vector (zero-cost register cast).
    #[inline]
    pub fn low_256(self) -> super::avx::AvxF32 {
        unsafe { super::avx::AvxF32(_mm512_castps512_ps256(self.0)) }
    }
    /// Lanes 8..16 as a 256-bit AVX vector.
    #[inline]
    pub fn high_256(self) -> super::avx::AvxF32 {
        // FIX: `_mm512_extractf32x8_ps` requires AVX-512DQ, but this module
        // only gates on AVX-512F. Route the extract through the f64 form
        // (AVX-512F) with bit-preserving casts instead; the immediate is a
        // const generic in current `core::arch`.
        unsafe {
            let hi = _mm512_extractf64x4_pd::<1>(_mm512_castps_pd(self.0));
            super::avx::AvxF32(_mm256_castpd_ps(hi))
        }
    }
    /// Interleaves the low floats of each 128-bit lane of `self` and `other`.
    #[inline]
    pub fn unpack_lo(self, other: Self) -> Self {
        unsafe { Self(_mm512_unpacklo_ps(self.0, other.0)) }
    }
    /// Interleaves the high floats of each 128-bit lane of `self` and `other`.
    #[inline]
    pub fn unpack_hi(self, other: Self) -> Self {
        unsafe { Self(_mm512_unpackhi_ps(self.0, other.0)) }
    }
}
/// Complex arithmetic over eight interleaved `[re, im]` `f32` pairs.
impl SimdComplex for Avx512F32 {
    /// Complex multiply per pair: `(ar*br - ai*bi, ar*bi + ai*br)`.
    /// `moveldup`/`movehdup` broadcast the real/imaginary parts; one
    /// `fmaddsub` does the subtract-even / add-odd combine.
    #[inline]
    fn cmul(self, other: Self) -> Self {
        unsafe {
            let a_re = _mm512_moveldup_ps(self.0); // (ar, ar)
            let a_im = _mm512_movehdup_ps(self.0); // (ai, ai)
            // FIX: permute immediate is a const generic in current
            // `core::arch`, not a runtime argument.
            let b_flip = _mm512_permute_ps::<0b10_11_00_01>(other.0); // (bi, br)
            Self(_mm512_fmaddsub_ps(
                a_re,
                other.0,
                _mm512_mul_ps(a_im, b_flip),
            ))
        }
    }
    /// Multiplies `self` by the complex conjugate of `other`: sign-flips the
    /// imaginary lanes of `other`, then reuses the `cmul` dataflow.
    #[inline]
    fn cmul_conj(self, other: Self) -> Self {
        unsafe {
            // `-0.0` in the odd (imaginary) lanes; `_mm512_set_ps` lists
            // lanes highest-first.
            let sign_mask = _mm512_set_ps(
                -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0, 0.0, -0.0,
                0.0,
            );
            let other_conj = _mm512_xor_ps(other.0, sign_mask);
            let a_re = _mm512_moveldup_ps(self.0);
            let a_im = _mm512_movehdup_ps(self.0);
            let b_flip = _mm512_permute_ps::<0b10_11_00_01>(other_conj);
            Self(_mm512_fmaddsub_ps(
                a_re,
                other_conj,
                _mm512_mul_ps(a_im, b_flip),
            ))
        }
    }
}
#[cfg(test)]
mod tests {
    //! Each test early-returns on machines without AVX-512F, so the suite is
    //! a no-op (not a failure) on unsupported hardware.
    use super::*;
    // splat + add across all eight f64 lanes.
    #[test]
    fn test_avx512_f64_basic() {
        if !has_avx512f() {
            return;
        }
        let a = Avx512F64::splat(2.0);
        let b = Avx512F64::splat(3.0);
        let c = a.add(b);
        for i in 0..8 {
            assert_eq!(c.extract(i), 5.0);
        }
    }
    // `new` places argument `a` in lane 0 through `h` in lane 7.
    #[test]
    fn test_avx512_f64_new() {
        if !has_avx512f() {
            return;
        }
        let v = Avx512F64::new(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
        for i in 0..8 {
            assert_eq!(v.extract(i), (i + 1) as f64);
        }
    }
    // fmadd computes self * a + b: 2 * 3 + 4 = 10 in every lane.
    #[test]
    fn test_avx512_f64_fmadd() {
        if !has_avx512f() {
            return;
        }
        let a = Avx512F64::splat(2.0);
        let b = Avx512F64::splat(3.0);
        let c = Avx512F64::splat(4.0);
        let result = a.fmadd(b, c);
        for i in 0..8 {
            assert_eq!(result.extract(i), 10.0);
        }
    }
    // Unaligned load followed by unaligned store round-trips all lanes.
    #[test]
    fn test_avx512_f64_load_store() {
        if !has_avx512f() {
            return;
        }
        let data = [1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let v = unsafe { Avx512F64::load_unaligned(data.as_ptr()) };
        let mut out = [0.0_f64; 8];
        unsafe { v.store_unaligned(out.as_mut_ptr()) };
        assert_eq!(data, out);
    }
    // Complex multiply in the first pair: (3+4i)(1+2i) = -5 + 10i.
    #[test]
    fn test_avx512_f64_cmul() {
        if !has_avx512f() {
            return;
        }
        let a = Avx512F64::new(3.0, 4.0, 1.0, 0.0, 1.0, 1.0, 2.0, 0.0);
        let b = Avx512F64::new(1.0, 2.0, 1.0, 0.0, 1.0, -1.0, 1.0, 0.0);
        let c = a.cmul(b);
        let tol = 1e-10;
        assert!((c.extract(0) - (-5.0)).abs() < tol);
        assert!((c.extract(1) - 10.0).abs() < tol);
    }
    // splat + mul across all sixteen f32 lanes.
    #[test]
    fn test_avx512_f32_basic() {
        if !has_avx512f() {
            return;
        }
        let a = Avx512F32::splat(2.0);
        let b = Avx512F32::splat(3.0);
        let c = a.mul(b);
        let mut out = [0.0_f32; 16];
        unsafe { c.store_unaligned(out.as_mut_ptr()) };
        for val in &out {
            assert_eq!(*val, 6.0);
        }
    }
    // f32 fmadd: 2 * 3 + 4 = 10 in every lane.
    #[test]
    fn test_avx512_f32_fmadd() {
        if !has_avx512f() {
            return;
        }
        let a = Avx512F32::splat(2.0);
        let b = Avx512F32::splat(3.0);
        let c = Avx512F32::splat(4.0);
        let result = a.fmadd(b, c);
        let mut out = [0.0_f32; 16];
        unsafe { result.store_unaligned(out.as_mut_ptr()) };
        for val in &out {
            assert_eq!(*val, 10.0);
        }
    }
    // f32 complex multiply: pair 0 is (3+4i)(1+2i) = -5 + 10i,
    // pair 1 is (1+0i)(1+0i) = 1 + 0i; remaining pairs are zero.
    #[test]
    fn test_avx512_f32_cmul() {
        if !has_avx512f() {
            return;
        }
        let mut input_a = [0.0_f32; 16];
        let mut input_b = [0.0_f32; 16];
        input_a[0] = 3.0;
        input_a[1] = 4.0;
        input_b[0] = 1.0;
        input_b[1] = 2.0;
        input_a[2] = 1.0;
        input_a[3] = 0.0;
        input_b[2] = 1.0;
        input_b[3] = 0.0;
        let a = unsafe { Avx512F32::load_unaligned(input_a.as_ptr()) };
        let b = unsafe { Avx512F32::load_unaligned(input_b.as_ptr()) };
        let c = a.cmul(b);
        let mut out = [0.0_f32; 16];
        unsafe { c.store_unaligned(out.as_mut_ptr()) };
        let tol = 1e-5;
        assert!((out[0] - (-5.0)).abs() < tol);
        assert!((out[1] - 10.0).abs() < tol);
        assert!((out[2] - 1.0).abs() < tol);
        assert!((out[3] - 0.0).abs() < tol);
    }
}