use super::super::{
AvxDescriptor, F32SimdVec, I32SimdVec, SimdDescriptor, SimdMask, U8SimdVec, U16SimdVec,
};
use crate::{Sse42Descriptor, U32SimdVec, impl_f32_array_interface};
use archmage::{SimdToken, arcane};
use archmage::intrinsics::x86_64::*;
use std::ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
Mul, MulAssign, Neg, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};
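/// Capability descriptor for the AVX-512 (x86-64-v4) backend.
///
/// Wraps an `archmage::X64V4Token`, the runtime proof that the required CPU
/// features were detected; every operation below forwards this token into an
/// `#[arcane]` inner function so the 512-bit intrinsics can be used safely.
///
/// A minimal usage sketch (assumes the `SimdDescriptor` and `F32SimdVec`
/// traits are in scope; `new()` returns `None` on CPUs without x86-64-v4
/// support):
///
/// ```ignore
/// if let Some(d) = Avx512Descriptor::new() {
///     d.call(|d| {
///         let v = F32VecAvx512::splat(d, 1.0);
///         let mut out = [0.0f32; 16];
///         v.store(&mut out);
///     });
/// }
/// ```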
#[derive(Clone, Copy, Debug)]
pub struct Avx512Descriptor(archmage::X64V4Token);
impl Avx512Descriptor {
#[inline]
pub fn from_token(token: archmage::X64V4Token) -> Self {
Self(token)
}
#[inline(always)]
pub fn token(&self) -> archmage::X64V4Token {
self.0
}
#[inline]
pub fn as_avx(&self) -> AvxDescriptor {
AvxDescriptor::from_token(self.0.v3())
}
}
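/// An 8-entry f32 lookup table kept in the low 256 bits of a zmm register.
/// `_mm512_castps256_ps512` leaves the upper half unspecified, so indices
/// passed to `table_lookup_bf16_8` must stay in `0..8`.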
#[derive(Clone, Copy, Debug)]
#[repr(transparent)]
pub struct Bf16Table8Avx512(__m512);
impl SimdDescriptor for Avx512Descriptor {
type F32Vec = F32VecAvx512;
type I32Vec = I32VecAvx512;
type U32Vec = U32VecAvx512;
type U8Vec = U8VecAvx512;
type U16Vec = U16VecAvx512;
type Mask = MaskAvx512;
type Bf16Table8 = Bf16Table8Avx512;
type Descriptor256 = AvxDescriptor;
type Descriptor128 = Sse42Descriptor;
#[inline]
fn maybe_downgrade_256bit(self) -> Self::Descriptor256 {
self.as_avx()
}
#[inline]
fn maybe_downgrade_128bit(self) -> Self::Descriptor128 {
self.as_avx().as_sse42()
}
fn new() -> Option<Self> {
archmage::X64V4Token::summon().map(Self::from_token)
}
fn call<R>(self, f: impl FnOnce(Self) -> R) -> R {
#[arcane]
#[inline(always)]
fn impl_<R>(
_: archmage::X64V4Token,
d: Avx512Descriptor,
f: impl FnOnce(Avx512Descriptor) -> R,
) -> R {
f(d)
}
impl_(self.token(), self, f)
}
}
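/// Defines a trait method whose body runs inside an `#[arcane]` inner
/// function, passing the descriptor's `X64V4Token` as proof that AVX-512
/// intrinsics may be used. `$this` names the receiver inside `$body`.
/// (The name notwithstanding, this is the AVX-512 backend's helper.)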
macro_rules! fn_avx {
(
$this:ident: $self_ty:ty,
fn $name:ident($($arg:ident: $ty:ty),* $(,)?) $(-> $ret:ty )? $body: block) => {
#[inline(always)]
fn $name(self: $self_ty, $($arg: $ty),*) $(-> $ret)? {
#[arcane]
#[inline(always)]
fn impl_(_t: archmage::X64V4Token, $this: $self_ty, $($arg: $ty),*) $(-> $ret)? $body
impl_(self.1.token(), self, $($arg),*)
}
};
}
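/// 16-lane f32 vector backed by a zmm register, tagged with its descriptor.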
#[derive(Clone, Copy, Debug)]
pub struct F32VecAvx512(__m512, Avx512Descriptor);
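/// Lane mask for the 16-lane vectors, held as a `__mmask16` (one bit per lane).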
#[derive(Clone, Copy, Debug)]
pub struct MaskAvx512(__mmask16, Avx512Descriptor);
impl F32SimdVec for F32VecAvx512 {
type Descriptor = Avx512Descriptor;
const LEN: usize = 16;
#[inline(always)]
fn load(d: Self::Descriptor, mem: &[f32]) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, mem: &[f32]) -> __m512 {
_mm512_loadu_ps(mem.first_chunk::<16>().unwrap())
}
Self(impl_(d.token(), mem), d)
}
#[inline(always)]
fn store(&self, mem: &mut [f32]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: __m512, mem: &mut [f32]) {
_mm512_storeu_ps(mem.first_chunk_mut::<16>().unwrap(), v)
}
impl_(self.1.token(), self.0, mem)
}
#[inline(always)]
fn store_interleaved_2(a: Self, b: Self, dest: &mut [f32]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, a: __m512, b: __m512, dest: &mut [f32]) {
assert!(dest.len() >= 2 * F32VecAvx512::LEN);
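// unpacklo/unpackhi interleave a and b within each 128-bit lane; the two
// permutex2var steps then reorder the lanes into one sequential
// `a0 b0 a1 b1 ...` stream (indices >= 16 select from the second operand).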
let lo = _mm512_unpacklo_ps(a, b);
let hi = _mm512_unpackhi_ps(a, b);
let idx_lo = _mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23);
let idx_hi =
_mm512_setr_epi32(8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31);
let out0 = _mm512_permutex2var_ps(lo, idx_lo, hi);
let out1 = _mm512_permutex2var_ps(lo, idx_hi, hi);
_mm512_storeu_ps(dest[..16].first_chunk_mut::<16>().unwrap(), out0);
_mm512_storeu_ps(dest[16..32].first_chunk_mut::<16>().unwrap(), out1);
}
impl_(a.1.token(), a.0, b.0, dest)
}
#[inline(always)]
fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [f32]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, a: __m512, b: __m512, c: __m512, dest: &mut [f32]) {
assert!(dest.len() >= 3 * F32VecAvx512::LEN);
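// Each output vector is built in two steps: permutex2var merges the a and b
// elements (the 0 index entries are placeholders), then a masked permutexvar
// drops the c elements into the remaining slots; the mask bits mark the
// c positions in the `a0 b0 c0 a1 b1 c1 ...` layout.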
let idx_ab0 = _mm512_setr_epi32(0, 16, 0, 1, 17, 0, 2, 18, 0, 3, 19, 0, 4, 20, 0, 5);
let idx_c0 = _mm512_setr_epi32(0, 0, 0, 0, 0, 1, 0, 0, 2, 0, 0, 3, 0, 0, 4, 0);
let idx_ab1 = _mm512_setr_epi32(21, 0, 6, 22, 0, 7, 23, 0, 8, 24, 0, 9, 25, 0, 10, 26);
let idx_c1 = _mm512_setr_epi32(0, 5, 0, 0, 6, 0, 0, 7, 0, 0, 8, 0, 0, 9, 0, 0);
let idx_ab2 =
_mm512_setr_epi32(0, 11, 27, 0, 12, 28, 0, 13, 29, 0, 14, 30, 0, 15, 31, 0);
let idx_c2 = _mm512_setr_epi32(10, 0, 0, 11, 0, 0, 12, 0, 0, 13, 0, 0, 14, 0, 0, 15);
let out0 = _mm512_permutex2var_ps(a, idx_ab0, b);
let out0 = _mm512_mask_permutexvar_ps(out0, 0b0100100100100100, idx_c0, c);
let out1 = _mm512_permutex2var_ps(a, idx_ab1, b);
let out1 = _mm512_mask_permutexvar_ps(out1, 0b0010010010010010, idx_c1, c);
let out2 = _mm512_permutex2var_ps(a, idx_ab2, b);
let out2 = _mm512_mask_permutexvar_ps(out2, 0b1001001001001001, idx_c2, c);
_mm512_storeu_ps(dest[..16].first_chunk_mut::<16>().unwrap(), out0);
_mm512_storeu_ps(dest[16..32].first_chunk_mut::<16>().unwrap(), out1);
_mm512_storeu_ps(dest[32..48].first_chunk_mut::<16>().unwrap(), out2);
}
impl_(a.1.token(), a.0, b.0, c.0, dest)
}
#[inline(always)]
fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [f32]) {
#[arcane]
#[inline(always)]
fn impl_(
_: archmage::X64V4Token,
a: __m512,
b: __m512,
c: __m512,
d: __m512,
dest: &mut [f32],
) {
assert!(dest.len() >= 4 * F32VecAvx512::LEN);
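// 4-way interleave: 32-bit unpacks build abcd quads within each 128-bit
// lane, the pd unpacks pair them up, and two permutex2var rounds put the
// 128-bit lanes into sequential order.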
let ab_lo = _mm512_unpacklo_ps(a, b);
let ab_hi = _mm512_unpackhi_ps(a, b);
let cd_lo = _mm512_unpacklo_ps(c, d);
let cd_hi = _mm512_unpackhi_ps(c, d);
let abcd_0 = _mm512_castpd_ps(_mm512_unpacklo_pd(
_mm512_castps_pd(ab_lo),
_mm512_castps_pd(cd_lo),
));
let abcd_1 = _mm512_castpd_ps(_mm512_unpackhi_pd(
_mm512_castps_pd(ab_lo),
_mm512_castps_pd(cd_lo),
));
let abcd_2 = _mm512_castpd_ps(_mm512_unpacklo_pd(
_mm512_castps_pd(ab_hi),
_mm512_castps_pd(cd_hi),
));
let abcd_3 = _mm512_castpd_ps(_mm512_unpackhi_pd(
_mm512_castps_pd(ab_hi),
_mm512_castps_pd(cd_hi),
));
let idx_even =
_mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27);
let idx_odd =
_mm512_setr_epi32(4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31);
let pair01_02 = _mm512_permutex2var_ps(abcd_0, idx_even, abcd_1);
let pair01_13 = _mm512_permutex2var_ps(abcd_0, idx_odd, abcd_1);
let pair23_02 = _mm512_permutex2var_ps(abcd_2, idx_even, abcd_3);
let pair23_13 = _mm512_permutex2var_ps(abcd_2, idx_odd, abcd_3);
let idx_0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23);
let idx_1 =
_mm512_setr_epi32(8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31);
let out0 = _mm512_permutex2var_ps(pair01_02, idx_0, pair23_02);
let out2 = _mm512_permutex2var_ps(pair01_02, idx_1, pair23_02);
let out1 = _mm512_permutex2var_ps(pair01_13, idx_0, pair23_13);
let out3 = _mm512_permutex2var_ps(pair01_13, idx_1, pair23_13);
_mm512_storeu_ps(dest[..16].first_chunk_mut::<16>().unwrap(), out0);
_mm512_storeu_ps(dest[16..32].first_chunk_mut::<16>().unwrap(), out1);
_mm512_storeu_ps(dest[32..48].first_chunk_mut::<16>().unwrap(), out2);
_mm512_storeu_ps(dest[48..64].first_chunk_mut::<16>().unwrap(), out3);
}
impl_(a.1.token(), a.0, b.0, c.0, d.0, dest)
}
#[inline(always)]
fn store_interleaved_8(
a: Self,
b: Self,
c: Self,
d: Self,
e: Self,
f: Self,
g: Self,
h: Self,
dest: &mut [f32],
) {
#[arcane]
#[inline(always)]
fn impl_(
_: archmage::X64V4Token,
a: __m512,
b: __m512,
c: __m512,
d: __m512,
e: __m512,
f: __m512,
g: __m512,
h: __m512,
dest: &mut [f32],
) {
assert!(dest.len() >= 8 * F32VecAvx512::LEN);
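// Same scheme as store_interleaved_4, applied to the abcd and efgh halves
// separately, with one extra permutex2var round to merge the two halves.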
let ab_lo = _mm512_unpacklo_ps(a, b);
let ab_hi = _mm512_unpackhi_ps(a, b);
let cd_lo = _mm512_unpacklo_ps(c, d);
let cd_hi = _mm512_unpackhi_ps(c, d);
let ef_lo = _mm512_unpacklo_ps(e, f);
let ef_hi = _mm512_unpackhi_ps(e, f);
let gh_lo = _mm512_unpacklo_ps(g, h);
let gh_hi = _mm512_unpackhi_ps(g, h);
let abcd_0 = _mm512_castpd_ps(_mm512_unpacklo_pd(
_mm512_castps_pd(ab_lo),
_mm512_castps_pd(cd_lo),
));
let abcd_1 = _mm512_castpd_ps(_mm512_unpackhi_pd(
_mm512_castps_pd(ab_lo),
_mm512_castps_pd(cd_lo),
));
let abcd_2 = _mm512_castpd_ps(_mm512_unpacklo_pd(
_mm512_castps_pd(ab_hi),
_mm512_castps_pd(cd_hi),
));
let abcd_3 = _mm512_castpd_ps(_mm512_unpackhi_pd(
_mm512_castps_pd(ab_hi),
_mm512_castps_pd(cd_hi),
));
let efgh_0 = _mm512_castpd_ps(_mm512_unpacklo_pd(
_mm512_castps_pd(ef_lo),
_mm512_castps_pd(gh_lo),
));
let efgh_1 = _mm512_castpd_ps(_mm512_unpackhi_pd(
_mm512_castps_pd(ef_lo),
_mm512_castps_pd(gh_lo),
));
let efgh_2 = _mm512_castpd_ps(_mm512_unpacklo_pd(
_mm512_castps_pd(ef_hi),
_mm512_castps_pd(gh_hi),
));
let efgh_3 = _mm512_castpd_ps(_mm512_unpackhi_pd(
_mm512_castps_pd(ef_hi),
_mm512_castps_pd(gh_hi),
));
let idx_02 =
_mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27);
let idx_13 =
_mm512_setr_epi32(4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31);
let full_0_02 = _mm512_permutex2var_ps(abcd_0, idx_02, efgh_0);
let full_0_13 = _mm512_permutex2var_ps(abcd_0, idx_13, efgh_0);
let full_1_02 = _mm512_permutex2var_ps(abcd_1, idx_02, efgh_1);
let full_1_13 = _mm512_permutex2var_ps(abcd_1, idx_13, efgh_1);
let full_2_02 = _mm512_permutex2var_ps(abcd_2, idx_02, efgh_2);
let full_2_13 = _mm512_permutex2var_ps(abcd_2, idx_13, efgh_2);
let full_3_02 = _mm512_permutex2var_ps(abcd_3, idx_02, efgh_3);
let full_3_13 = _mm512_permutex2var_ps(abcd_3, idx_13, efgh_3);
let idx_lo = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23);
let idx_hi =
_mm512_setr_epi32(8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31);
let out0 = _mm512_permutex2var_ps(full_0_02, idx_lo, full_1_02);
let out1 = _mm512_permutex2var_ps(full_2_02, idx_lo, full_3_02);
let out2 = _mm512_permutex2var_ps(full_0_13, idx_lo, full_1_13);
let out3 = _mm512_permutex2var_ps(full_2_13, idx_lo, full_3_13);
let out4 = _mm512_permutex2var_ps(full_0_02, idx_hi, full_1_02);
let out5 = _mm512_permutex2var_ps(full_2_02, idx_hi, full_3_02);
let out6 = _mm512_permutex2var_ps(full_0_13, idx_hi, full_1_13);
let out7 = _mm512_permutex2var_ps(full_2_13, idx_hi, full_3_13);
_mm512_storeu_ps(dest[..16].first_chunk_mut::<16>().unwrap(), out0);
_mm512_storeu_ps(dest[16..32].first_chunk_mut::<16>().unwrap(), out1);
_mm512_storeu_ps(dest[32..48].first_chunk_mut::<16>().unwrap(), out2);
_mm512_storeu_ps(dest[48..64].first_chunk_mut::<16>().unwrap(), out3);
_mm512_storeu_ps(dest[64..80].first_chunk_mut::<16>().unwrap(), out4);
_mm512_storeu_ps(dest[80..96].first_chunk_mut::<16>().unwrap(), out5);
_mm512_storeu_ps(dest[96..112].first_chunk_mut::<16>().unwrap(), out6);
_mm512_storeu_ps(dest[112..128].first_chunk_mut::<16>().unwrap(), out7);
}
impl_(a.1.token(), a.0, b.0, c.0, d.0, e.0, f.0, g.0, h.0, dest)
}
#[inline(always)]
fn load_deinterleaved_2(d: Self::Descriptor, src: &[f32]) -> (Self, Self) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, src: &[f32]) -> (__m512, __m512) {
assert!(src.len() >= 2 * F32VecAvx512::LEN);
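// One permutex2var per output gathers the even-indexed elements (a) and the
// odd-indexed elements (b) across both input vectors.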
let in0 = _mm512_loadu_ps(src[..16].first_chunk::<16>().unwrap());
let in1 = _mm512_loadu_ps(src[16..32].first_chunk::<16>().unwrap());
let idx_a =
_mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
let idx_b =
_mm512_setr_epi32(1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31);
let a = _mm512_permutex2var_ps(in0, idx_a, in1);
let b = _mm512_permutex2var_ps(in0, idx_b, in1);
(a, b)
}
let (a, b) = impl_(d.token(), src);
(Self(a, d), Self(b, d))
}
#[inline(always)]
fn load_deinterleaved_3(d: Self::Descriptor, src: &[f32]) -> (Self, Self, Self) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, src: &[f32]) -> (__m512, __m512, __m512) {
assert!(src.len() >= 3 * F32VecAvx512::LEN);
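// Each output gathers every third element: one permutex2var covers the
// in0/in1 span, another the in1/in2 span, and a blend stitches the two
// partial results (the 0 index entries are placeholders in lanes the blend
// discards).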
let in0 = _mm512_loadu_ps(src[..16].first_chunk::<16>().unwrap());
let in1 = _mm512_loadu_ps(src[16..32].first_chunk::<16>().unwrap());
let in2 = _mm512_loadu_ps(src[32..48].first_chunk::<16>().unwrap());
let idx_a_01 = _mm512_setr_epi32(0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30, 0, 0, 0, 0, 0);
let idx_b_01 =
_mm512_setr_epi32(1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 0, 0, 0, 0, 0);
let idx_c_01 = _mm512_setr_epi32(2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 0, 0, 0, 0, 0, 0);
let idx_a_12 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 20, 23, 26, 29);
let idx_b_12 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 21, 24, 27, 30);
let idx_c_12 = _mm512_setr_epi32(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 19, 22, 25, 28, 31);
let a_01 = _mm512_permutex2var_ps(in0, idx_a_01, in1);
let a_12 = _mm512_permutex2var_ps(in1, idx_a_12, in2);
let a = _mm512_mask_blend_ps(0xF800, a_01, a_12);
let b_01 = _mm512_permutex2var_ps(in0, idx_b_01, in1);
let b_12 = _mm512_permutex2var_ps(in1, idx_b_12, in2);
let b = _mm512_mask_blend_ps(0xF800, b_01, b_12);
let c_01 = _mm512_permutex2var_ps(in0, idx_c_01, in1);
let c_12 = _mm512_permutex2var_ps(in1, idx_c_12, in2);
let c = _mm512_mask_blend_ps(0xFC00, c_01, c_12);
(a, b, c)
}
let (a, b, c) = impl_(d.token(), src);
(Self(a, d), Self(b, d), Self(c, d))
}
#[inline(always)]
fn load_deinterleaved_4(d: Self::Descriptor, src: &[f32]) -> (Self, Self, Self, Self) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, src: &[f32]) -> (__m512, __m512, __m512, __m512) {
assert!(src.len() >= 4 * F32VecAvx512::LEN);
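// Each index vector gathers every fourth element from a pair of inputs; the
// low half of each output comes from in0/in1 and the high half from in2/in3,
// merged with a blend.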
let in0 = _mm512_loadu_ps(src[..16].first_chunk::<16>().unwrap());
let in1 = _mm512_loadu_ps(src[16..32].first_chunk::<16>().unwrap());
let in2 = _mm512_loadu_ps(src[32..48].first_chunk::<16>().unwrap());
let in3 = _mm512_loadu_ps(src[48..64].first_chunk::<16>().unwrap());
let idx_a = _mm512_setr_epi32(0, 4, 8, 12, 16, 20, 24, 28, 0, 4, 8, 12, 16, 20, 24, 28);
let idx_b = _mm512_setr_epi32(1, 5, 9, 13, 17, 21, 25, 29, 1, 5, 9, 13, 17, 21, 25, 29);
let idx_c =
_mm512_setr_epi32(2, 6, 10, 14, 18, 22, 26, 30, 2, 6, 10, 14, 18, 22, 26, 30);
let idx_d =
_mm512_setr_epi32(3, 7, 11, 15, 19, 23, 27, 31, 3, 7, 11, 15, 19, 23, 27, 31);
let a01 = _mm512_permutex2var_ps(in0, idx_a, in1);
let a23 = _mm512_permutex2var_ps(in2, idx_a, in3);
let a = _mm512_mask_blend_ps(0xFF00, a01, a23);
let b01 = _mm512_permutex2var_ps(in0, idx_b, in1);
let b23 = _mm512_permutex2var_ps(in2, idx_b, in3);
let b = _mm512_mask_blend_ps(0xFF00, b01, b23);
let c01 = _mm512_permutex2var_ps(in0, idx_c, in1);
let c23 = _mm512_permutex2var_ps(in2, idx_c, in3);
let c = _mm512_mask_blend_ps(0xFF00, c01, c23);
let d01 = _mm512_permutex2var_ps(in0, idx_d, in1);
let d23 = _mm512_permutex2var_ps(in2, idx_d, in3);
let dv = _mm512_mask_blend_ps(0xFF00, d01, d23);
(a, b, c, dv)
}
let (a, b, c, dv) = impl_(d.token(), src);
(Self(a, d), Self(b, d), Self(c, d), Self(dv, d))
}
fn_avx!(this: F32VecAvx512, fn mul_add(mul: F32VecAvx512, add: F32VecAvx512) -> F32VecAvx512 {
F32VecAvx512(_mm512_fmadd_ps(this.0, mul.0, add.0), this.1)
});
fn_avx!(this: F32VecAvx512, fn neg_mul_add(mul: F32VecAvx512, add: F32VecAvx512) -> F32VecAvx512 {
F32VecAvx512(_mm512_fnmadd_ps(this.0, mul.0, add.0), this.1)
});
#[inline(always)]
fn splat(d: Self::Descriptor, v: f32) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: f32) -> __m512 {
_mm512_set1_ps(v)
}
Self(impl_(d.token(), v), d)
}
#[inline(always)]
fn zero(d: Self::Descriptor) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token) -> __m512 {
_mm512_setzero_ps()
}
Self(impl_(d.token()), d)
}
fn_avx!(this: F32VecAvx512, fn abs() -> F32VecAvx512 {
F32VecAvx512(_mm512_abs_ps(this.0), this.1)
});
fn_avx!(this: F32VecAvx512, fn floor() -> F32VecAvx512 {
F32VecAvx512(_mm512_roundscale_ps::<{ _MM_FROUND_FLOOR }>(this.0), this.1)
});
fn_avx!(this: F32VecAvx512, fn sqrt() -> F32VecAvx512 {
F32VecAvx512(_mm512_sqrt_ps(this.0), this.1)
});
fn_avx!(this: F32VecAvx512, fn neg() -> F32VecAvx512 {
F32VecAvx512(
_mm512_castsi512_ps(_mm512_xor_si512(
_mm512_set1_epi32(i32::MIN),
_mm512_castps_si512(this.0),
)),
this.1,
)
});
fn_avx!(this: F32VecAvx512, fn copysign(sign: F32VecAvx512) -> F32VecAvx512 {
let sign_mask = _mm512_set1_epi32(i32::MIN);
F32VecAvx512(
_mm512_castsi512_ps(_mm512_or_si512(
_mm512_andnot_si512(sign_mask, _mm512_castps_si512(this.0)),
_mm512_and_si512(sign_mask, _mm512_castps_si512(sign.0)),
)),
this.1,
)
});
fn_avx!(this: F32VecAvx512, fn max(other: F32VecAvx512) -> F32VecAvx512 {
F32VecAvx512(_mm512_max_ps(this.0, other.0), this.1)
});
fn_avx!(this: F32VecAvx512, fn min(other: F32VecAvx512) -> F32VecAvx512 {
F32VecAvx512(_mm512_min_ps(this.0, other.0), this.1)
});
fn_avx!(this: F32VecAvx512, fn gt(other: F32VecAvx512) -> MaskAvx512 {
MaskAvx512(_mm512_cmp_ps_mask::<{_CMP_GT_OQ}>(this.0, other.0), this.1)
});
fn_avx!(this: F32VecAvx512, fn as_i32() -> I32VecAvx512 {
I32VecAvx512(_mm512_cvtps_epi32(this.0), this.1)
});
fn_avx!(this: F32VecAvx512, fn bitcast_to_i32() -> I32VecAvx512 {
I32VecAvx512(_mm512_castps_si512(this.0), this.1)
});
#[inline(always)]
fn prepare_table_bf16_8(d: Avx512Descriptor, table: &[f32; 8]) -> Bf16Table8Avx512 {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, table: &[f32; 8]) -> __m512 {
let table_256 = _mm256_loadu_ps(table);
_mm512_castps256_ps512(table_256)
}
Bf16Table8Avx512(impl_(d.token(), table))
}
#[inline(always)]
fn table_lookup_bf16_8(
d: Avx512Descriptor,
table: Bf16Table8Avx512,
indices: I32VecAvx512,
) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, indices: __m512i, table: __m512) -> __m512 {
_mm512_permutexvar_ps(indices, table)
}
F32VecAvx512(impl_(d.token(), indices.0, table.0), d)
}
#[inline(always)]
fn round_store_u8(self, dest: &mut [u8]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: __m512, dest: &mut [u8]) {
assert!(dest.len() >= F32VecAvx512::LEN);
let rounded = _mm512_roundscale_ps::<{ _MM_FROUND_TO_NEAREST_INT }>(v);
let i32s = _mm512_cvtps_epi32(rounded);
let u8s = _mm512_cvtusepi32_epi8(i32s);
_mm_storeu_si128(dest.first_chunk_mut::<16>().unwrap(), u8s);
}
impl_(self.1.token(), self.0, dest)
}
#[inline(always)]
fn round_store_u16(self, dest: &mut [u16]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: __m512, dest: &mut [u16]) {
assert!(dest.len() >= F32VecAvx512::LEN);
let rounded = _mm512_roundscale_ps::<{ _MM_FROUND_TO_NEAREST_INT }>(v);
let i32s = _mm512_cvtps_epi32(rounded);
let u16s = _mm512_cvtusepi32_epi16(i32s);
_mm256_storeu_si256(dest.first_chunk_mut::<16>().unwrap(), u16s);
}
impl_(self.1.token(), self.0, dest)
}
impl_f32_array_interface!();
#[inline(always)]
fn load_f16_bits(d: Self::Descriptor, mem: &[u16]) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, d: Avx512Descriptor, mem: &[u16]) -> F32VecAvx512 {
assert!(mem.len() >= F32VecAvx512::LEN);
let bits = _mm256_loadu_si256(mem.first_chunk::<16>().unwrap());
F32VecAvx512(_mm512_cvtph_ps(bits), d)
}
impl_(d.token(), d, mem)
}
#[inline(always)]
fn store_f16_bits(self, dest: &mut [u16]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: __m512, dest: &mut [u16]) {
assert!(dest.len() >= F32VecAvx512::LEN);
let bits = _mm512_cvtps_ph::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(v);
_mm256_storeu_si256(dest.first_chunk_mut::<16>().unwrap(), bits);
}
impl_(self.1.token(), self.0, dest)
}
#[inline(always)]
fn transpose_square(d: Self::Descriptor, data: &mut [Self::UnderlyingArray], stride: usize) {
#[arcane]
#[inline(always)]
fn impl_(
_: archmage::X64V4Token,
d: Avx512Descriptor,
data: &mut [[f32; 16]],
stride: usize,
) {
assert!(data.len() > stride * 15);
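// In-register 16x16 transpose: 32-bit unpacks interleave adjacent rows,
// 64-bit unpacks build 4-row groups, then two rounds of 64-bit permutex2var
// merge rows 4 apart and 8 apart to complete the transpose.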
let r0 = F32VecAvx512::load_array(d, &data[0]).0;
let r1 = F32VecAvx512::load_array(d, &data[1 * stride]).0;
let r2 = F32VecAvx512::load_array(d, &data[2 * stride]).0;
let r3 = F32VecAvx512::load_array(d, &data[3 * stride]).0;
let r4 = F32VecAvx512::load_array(d, &data[4 * stride]).0;
let r5 = F32VecAvx512::load_array(d, &data[5 * stride]).0;
let r6 = F32VecAvx512::load_array(d, &data[6 * stride]).0;
let r7 = F32VecAvx512::load_array(d, &data[7 * stride]).0;
let r8 = F32VecAvx512::load_array(d, &data[8 * stride]).0;
let r9 = F32VecAvx512::load_array(d, &data[9 * stride]).0;
let r10 = F32VecAvx512::load_array(d, &data[10 * stride]).0;
let r11 = F32VecAvx512::load_array(d, &data[11 * stride]).0;
let r12 = F32VecAvx512::load_array(d, &data[12 * stride]).0;
let r13 = F32VecAvx512::load_array(d, &data[13 * stride]).0;
let r14 = F32VecAvx512::load_array(d, &data[14 * stride]).0;
let r15 = F32VecAvx512::load_array(d, &data[15 * stride]).0;
let t0 = _mm512_unpacklo_ps(r0, r1);
let t1 = _mm512_unpackhi_ps(r0, r1);
let t2 = _mm512_unpacklo_ps(r2, r3);
let t3 = _mm512_unpackhi_ps(r2, r3);
let t4 = _mm512_unpacklo_ps(r4, r5);
let t5 = _mm512_unpackhi_ps(r4, r5);
let t6 = _mm512_unpacklo_ps(r6, r7);
let t7 = _mm512_unpackhi_ps(r6, r7);
let t8 = _mm512_unpacklo_ps(r8, r9);
let t9 = _mm512_unpackhi_ps(r8, r9);
let t10 = _mm512_unpacklo_ps(r10, r11);
let t11 = _mm512_unpackhi_ps(r10, r11);
let t12 = _mm512_unpacklo_ps(r12, r13);
let t13 = _mm512_unpackhi_ps(r12, r13);
let t14 = _mm512_unpacklo_ps(r14, r15);
let t15 = _mm512_unpackhi_ps(r14, r15);
let t0 = _mm512_castps_pd(t0);
let t1 = _mm512_castps_pd(t1);
let t2 = _mm512_castps_pd(t2);
let t3 = _mm512_castps_pd(t3);
let t4 = _mm512_castps_pd(t4);
let t5 = _mm512_castps_pd(t5);
let t6 = _mm512_castps_pd(t6);
let t7 = _mm512_castps_pd(t7);
let t8 = _mm512_castps_pd(t8);
let t9 = _mm512_castps_pd(t9);
let t10 = _mm512_castps_pd(t10);
let t11 = _mm512_castps_pd(t11);
let t12 = _mm512_castps_pd(t12);
let t13 = _mm512_castps_pd(t13);
let t14 = _mm512_castps_pd(t14);
let t15 = _mm512_castps_pd(t15);
let s0 = _mm512_unpacklo_pd(t0, t2);
let s1 = _mm512_unpackhi_pd(t0, t2);
let s2 = _mm512_unpacklo_pd(t1, t3);
let s3 = _mm512_unpackhi_pd(t1, t3);
let s4 = _mm512_unpacklo_pd(t4, t6);
let s5 = _mm512_unpackhi_pd(t4, t6);
let s6 = _mm512_unpacklo_pd(t5, t7);
let s7 = _mm512_unpackhi_pd(t5, t7);
let s8 = _mm512_unpacklo_pd(t8, t10);
let s9 = _mm512_unpackhi_pd(t8, t10);
let s10 = _mm512_unpacklo_pd(t9, t11);
let s11 = _mm512_unpackhi_pd(t9, t11);
let s12 = _mm512_unpacklo_pd(t12, t14);
let s13 = _mm512_unpackhi_pd(t12, t14);
let s14 = _mm512_unpacklo_pd(t13, t15);
let s15 = _mm512_unpackhi_pd(t13, t15);
let idx_hi = _mm512_setr_epi64(0, 1, 8, 9, 4, 5, 12, 13);
let idx_lo = _mm512_add_epi64(idx_hi, _mm512_set1_epi64(2));
let c0 = _mm512_permutex2var_pd(s0, idx_hi, s4);
let c1 = _mm512_permutex2var_pd(s1, idx_hi, s5);
let c2 = _mm512_permutex2var_pd(s2, idx_hi, s6);
let c3 = _mm512_permutex2var_pd(s3, idx_hi, s7);
let c4 = _mm512_permutex2var_pd(s0, idx_lo, s4);
let c5 = _mm512_permutex2var_pd(s1, idx_lo, s5);
let c6 = _mm512_permutex2var_pd(s2, idx_lo, s6);
let c7 = _mm512_permutex2var_pd(s3, idx_lo, s7);
let c8 = _mm512_permutex2var_pd(s8, idx_hi, s12);
let c9 = _mm512_permutex2var_pd(s9, idx_hi, s13);
let c10 = _mm512_permutex2var_pd(s10, idx_hi, s14);
let c11 = _mm512_permutex2var_pd(s11, idx_hi, s15);
let c12 = _mm512_permutex2var_pd(s8, idx_lo, s12);
let c13 = _mm512_permutex2var_pd(s9, idx_lo, s13);
let c14 = _mm512_permutex2var_pd(s10, idx_lo, s14);
let c15 = _mm512_permutex2var_pd(s11, idx_lo, s15);
let idx_hi = _mm512_setr_epi64(0, 1, 2, 3, 8, 9, 10, 11);
let idx_lo = _mm512_add_epi64(idx_hi, _mm512_set1_epi64(4));
let o0 = _mm512_permutex2var_pd(c0, idx_hi, c8);
let o1 = _mm512_permutex2var_pd(c1, idx_hi, c9);
let o2 = _mm512_permutex2var_pd(c2, idx_hi, c10);
let o3 = _mm512_permutex2var_pd(c3, idx_hi, c11);
let o4 = _mm512_permutex2var_pd(c4, idx_hi, c12);
let o5 = _mm512_permutex2var_pd(c5, idx_hi, c13);
let o6 = _mm512_permutex2var_pd(c6, idx_hi, c14);
let o7 = _mm512_permutex2var_pd(c7, idx_hi, c15);
let o8 = _mm512_permutex2var_pd(c0, idx_lo, c8);
let o9 = _mm512_permutex2var_pd(c1, idx_lo, c9);
let o10 = _mm512_permutex2var_pd(c2, idx_lo, c10);
let o11 = _mm512_permutex2var_pd(c3, idx_lo, c11);
let o12 = _mm512_permutex2var_pd(c4, idx_lo, c12);
let o13 = _mm512_permutex2var_pd(c5, idx_lo, c13);
let o14 = _mm512_permutex2var_pd(c6, idx_lo, c14);
let o15 = _mm512_permutex2var_pd(c7, idx_lo, c15);
let o0 = _mm512_castpd_ps(o0);
let o1 = _mm512_castpd_ps(o1);
let o2 = _mm512_castpd_ps(o2);
let o3 = _mm512_castpd_ps(o3);
let o4 = _mm512_castpd_ps(o4);
let o5 = _mm512_castpd_ps(o5);
let o6 = _mm512_castpd_ps(o6);
let o7 = _mm512_castpd_ps(o7);
let o8 = _mm512_castpd_ps(o8);
let o9 = _mm512_castpd_ps(o9);
let o10 = _mm512_castpd_ps(o10);
let o11 = _mm512_castpd_ps(o11);
let o12 = _mm512_castpd_ps(o12);
let o13 = _mm512_castpd_ps(o13);
let o14 = _mm512_castpd_ps(o14);
let o15 = _mm512_castpd_ps(o15);
F32VecAvx512(o0, d).store_array(&mut data[0]);
F32VecAvx512(o1, d).store_array(&mut data[1 * stride]);
F32VecAvx512(o2, d).store_array(&mut data[2 * stride]);
F32VecAvx512(o3, d).store_array(&mut data[3 * stride]);
F32VecAvx512(o4, d).store_array(&mut data[4 * stride]);
F32VecAvx512(o5, d).store_array(&mut data[5 * stride]);
F32VecAvx512(o6, d).store_array(&mut data[6 * stride]);
F32VecAvx512(o7, d).store_array(&mut data[7 * stride]);
F32VecAvx512(o8, d).store_array(&mut data[8 * stride]);
F32VecAvx512(o9, d).store_array(&mut data[9 * stride]);
F32VecAvx512(o10, d).store_array(&mut data[10 * stride]);
F32VecAvx512(o11, d).store_array(&mut data[11 * stride]);
F32VecAvx512(o12, d).store_array(&mut data[12 * stride]);
F32VecAvx512(o13, d).store_array(&mut data[13 * stride]);
F32VecAvx512(o14, d).store_array(&mut data[14 * stride]);
F32VecAvx512(o15, d).store_array(&mut data[15 * stride]);
}
impl_(d.token(), d, data, stride)
}
}
impl Add<F32VecAvx512> for F32VecAvx512 {
type Output = F32VecAvx512;
fn_avx!(this: F32VecAvx512, fn add(rhs: F32VecAvx512) -> F32VecAvx512 {
F32VecAvx512(_mm512_add_ps(this.0, rhs.0), this.1)
});
}
impl Sub<F32VecAvx512> for F32VecAvx512 {
type Output = F32VecAvx512;
fn_avx!(this: F32VecAvx512, fn sub(rhs: F32VecAvx512) -> F32VecAvx512 {
F32VecAvx512(_mm512_sub_ps(this.0, rhs.0), this.1)
});
}
impl Mul<F32VecAvx512> for F32VecAvx512 {
type Output = F32VecAvx512;
fn_avx!(this: F32VecAvx512, fn mul(rhs: F32VecAvx512) -> F32VecAvx512 {
F32VecAvx512(_mm512_mul_ps(this.0, rhs.0), this.1)
});
}
impl Div<F32VecAvx512> for F32VecAvx512 {
type Output = F32VecAvx512;
fn_avx!(this: F32VecAvx512, fn div(rhs: F32VecAvx512) -> F32VecAvx512 {
F32VecAvx512(_mm512_div_ps(this.0, rhs.0), this.1)
});
}
impl AddAssign<F32VecAvx512> for F32VecAvx512 {
fn_avx!(this: &mut F32VecAvx512, fn add_assign(rhs: F32VecAvx512) {
this.0 = _mm512_add_ps(this.0, rhs.0)
});
}
impl SubAssign<F32VecAvx512> for F32VecAvx512 {
fn_avx!(this: &mut F32VecAvx512, fn sub_assign(rhs: F32VecAvx512) {
this.0 = _mm512_sub_ps(this.0, rhs.0)
});
}
impl MulAssign<F32VecAvx512> for F32VecAvx512 {
fn_avx!(this: &mut F32VecAvx512, fn mul_assign(rhs: F32VecAvx512) {
this.0 = _mm512_mul_ps(this.0, rhs.0)
});
}
impl DivAssign<F32VecAvx512> for F32VecAvx512 {
fn_avx!(this: &mut F32VecAvx512, fn div_assign(rhs: F32VecAvx512) {
this.0 = _mm512_div_ps(this.0, rhs.0)
});
}
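/// 16-lane i32 vector backed by a zmm register.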
#[derive(Clone, Copy, Debug)]
pub struct I32VecAvx512(__m512i, Avx512Descriptor);
impl I32SimdVec for I32VecAvx512 {
type Descriptor = Avx512Descriptor;
const LEN: usize = 16;
#[inline(always)]
fn load(d: Self::Descriptor, mem: &[i32]) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, mem: &[i32]) -> __m512i {
_mm512_loadu_epi32(mem.first_chunk::<16>().unwrap())
}
Self(impl_(d.token(), mem), d)
}
#[inline(always)]
fn store(&self, mem: &mut [i32]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: __m512i, mem: &mut [i32]) {
_mm512_storeu_epi32(mem.first_chunk_mut::<16>().unwrap(), v)
}
impl_(self.1.token(), self.0, mem)
}
#[inline(always)]
fn splat(d: Self::Descriptor, v: i32) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: i32) -> __m512i {
_mm512_set1_epi32(v)
}
Self(impl_(d.token(), v), d)
}
fn_avx!(this: I32VecAvx512, fn as_f32() -> F32VecAvx512 {
F32VecAvx512(_mm512_cvtepi32_ps(this.0), this.1)
});
fn_avx!(this: I32VecAvx512, fn bitcast_to_f32() -> F32VecAvx512 {
F32VecAvx512(_mm512_castsi512_ps(this.0), this.1)
});
#[inline(always)]
fn bitcast_to_u32(self) -> U32VecAvx512 {
U32VecAvx512(self.0, self.1)
}
fn_avx!(this: I32VecAvx512, fn abs() -> I32VecAvx512 {
I32VecAvx512(_mm512_abs_epi32(this.0), this.1)
});
fn_avx!(this: I32VecAvx512, fn gt(rhs: I32VecAvx512) -> MaskAvx512 {
MaskAvx512(_mm512_cmpgt_epi32_mask(this.0, rhs.0), this.1)
});
fn_avx!(this: I32VecAvx512, fn lt_zero() -> MaskAvx512 {
I32VecAvx512(_mm512_setzero_epi32(), this.1).gt(this)
});
fn_avx!(this: I32VecAvx512, fn eq(rhs: I32VecAvx512) -> MaskAvx512 {
MaskAvx512(_mm512_cmpeq_epi32_mask(this.0, rhs.0), this.1)
});
fn_avx!(this: I32VecAvx512, fn eq_zero() -> MaskAvx512 {
I32VecAvx512(_mm512_setzero_epi32(), this.1).eq(this)
});
#[inline(always)]
fn shl<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self {
#[arcane]
#[inline(always)]
fn impl_<const AMOUNT_U: u32>(_: archmage::X64V4Token, v: __m512i) -> __m512i {
_mm512_slli_epi32::<AMOUNT_U>(v)
}
Self(impl_::<AMOUNT_U>(self.1.token(), self.0), self.1)
}
#[inline(always)]
fn shr<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self {
#[arcane]
#[inline(always)]
fn impl_<const AMOUNT_U: u32>(_: archmage::X64V4Token, v: __m512i) -> __m512i {
_mm512_srai_epi32::<AMOUNT_U>(v)
}
Self(impl_::<AMOUNT_U>(self.1.token(), self.0), self.1)
}
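// Widening multiply: _mm512_mul_epi32 yields signed 64-bit products of the
// even lanes; shifting both inputs right by 32 yields the odd-lane products.
// The final permute selects the high 32 bits of each product.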
fn_avx!(this: I32VecAvx512, fn mul_wide_take_high(rhs: I32VecAvx512) -> I32VecAvx512 {
let l = _mm512_mul_epi32(this.0, rhs.0);
let h = _mm512_mul_epi32(_mm512_srli_epi64::<32>(this.0), _mm512_srli_epi64::<32>(rhs.0));
let idx = _mm512_setr_epi32(1, 17, 3, 19, 5, 21, 7, 23, 9, 25, 11, 27, 13, 29, 15, 31);
I32VecAvx512(_mm512_permutex2var_epi32(l, idx, h), this.1)
});
#[inline(always)]
fn store_u16(self, dest: &mut [u16]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: __m512i, dest: &mut [u16]) {
assert!(dest.len() >= I32VecAvx512::LEN);
let tmp = _mm512_cvtepi32_epi16(v);
_mm256_storeu_si256(dest.first_chunk_mut::<16>().unwrap(), tmp);
}
impl_(self.1.token(), self.0, dest)
}
#[inline(always)]
fn store_u8(self, dest: &mut [u8]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: __m512i, dest: &mut [u8]) {
assert!(dest.len() >= I32VecAvx512::LEN);
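// Scalar fallback: spill to a stack buffer and truncate each lane to its
// low byte.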
let mut tmp = [0i32; 16];
_mm512_storeu_si512(&mut tmp, v);
for i in 0..16 {
dest[i] = tmp[i] as u8;
}
}
impl_(self.1.token(), self.0, dest)
}
}
impl Add<I32VecAvx512> for I32VecAvx512 {
type Output = I32VecAvx512;
fn_avx!(this: I32VecAvx512, fn add(rhs: I32VecAvx512) -> I32VecAvx512 {
I32VecAvx512(_mm512_add_epi32(this.0, rhs.0), this.1)
});
}
impl Sub<I32VecAvx512> for I32VecAvx512 {
type Output = I32VecAvx512;
fn_avx!(this: I32VecAvx512, fn sub(rhs: I32VecAvx512) -> I32VecAvx512 {
I32VecAvx512(_mm512_sub_epi32(this.0, rhs.0), this.1)
});
}
impl Mul<I32VecAvx512> for I32VecAvx512 {
type Output = I32VecAvx512;
fn_avx!(this: I32VecAvx512, fn mul(rhs: I32VecAvx512) -> I32VecAvx512 {
I32VecAvx512(_mm512_mullo_epi32(this.0, rhs.0), this.1)
});
}
impl Neg for I32VecAvx512 {
type Output = I32VecAvx512;
fn_avx!(this: I32VecAvx512, fn neg() -> I32VecAvx512 {
I32VecAvx512(_mm512_setzero_epi32(), this.1) - this
});
}
impl Shl<I32VecAvx512> for I32VecAvx512 {
type Output = I32VecAvx512;
fn_avx!(this: I32VecAvx512, fn shl(rhs: I32VecAvx512) -> I32VecAvx512 {
I32VecAvx512(_mm512_sllv_epi32(this.0, rhs.0), this.1)
});
}
impl Shr<I32VecAvx512> for I32VecAvx512 {
type Output = I32VecAvx512;
fn_avx!(this: I32VecAvx512, fn shr(rhs: I32VecAvx512) -> I32VecAvx512 {
I32VecAvx512(_mm512_srav_epi32(this.0, rhs.0), this.1)
});
}
impl BitAnd<I32VecAvx512> for I32VecAvx512 {
type Output = I32VecAvx512;
fn_avx!(this: I32VecAvx512, fn bitand(rhs: I32VecAvx512) -> I32VecAvx512 {
I32VecAvx512(_mm512_and_si512(this.0, rhs.0), this.1)
});
}
impl BitOr<I32VecAvx512> for I32VecAvx512 {
type Output = I32VecAvx512;
fn_avx!(this: I32VecAvx512, fn bitor(rhs: I32VecAvx512) -> I32VecAvx512 {
I32VecAvx512(_mm512_or_si512(this.0, rhs.0), this.1)
});
}
impl BitXor<I32VecAvx512> for I32VecAvx512 {
type Output = I32VecAvx512;
fn_avx!(this: I32VecAvx512, fn bitxor(rhs: I32VecAvx512) -> I32VecAvx512 {
I32VecAvx512(_mm512_xor_si512(this.0, rhs.0), this.1)
});
}
impl AddAssign<I32VecAvx512> for I32VecAvx512 {
fn_avx!(this: &mut I32VecAvx512, fn add_assign(rhs: I32VecAvx512) {
this.0 = _mm512_add_epi32(this.0, rhs.0)
});
}
impl SubAssign<I32VecAvx512> for I32VecAvx512 {
fn_avx!(this: &mut I32VecAvx512, fn sub_assign(rhs: I32VecAvx512) {
this.0 = _mm512_sub_epi32(this.0, rhs.0)
});
}
impl MulAssign<I32VecAvx512> for I32VecAvx512 {
fn_avx!(this: &mut I32VecAvx512, fn mul_assign(rhs: I32VecAvx512) {
this.0 = _mm512_mullo_epi32(this.0, rhs.0)
});
}
impl ShlAssign<I32VecAvx512> for I32VecAvx512 {
fn_avx!(this: &mut I32VecAvx512, fn shl_assign(rhs: I32VecAvx512) {
this.0 = _mm512_sllv_epi32(this.0, rhs.0)
});
}
impl ShrAssign<I32VecAvx512> for I32VecAvx512 {
fn_avx!(this: &mut I32VecAvx512, fn shr_assign(rhs: I32VecAvx512) {
this.0 = _mm512_srav_epi32(this.0, rhs.0)
});
}
impl BitAndAssign<I32VecAvx512> for I32VecAvx512 {
fn_avx!(this: &mut I32VecAvx512, fn bitand_assign(rhs: I32VecAvx512) {
this.0 = _mm512_and_si512(this.0, rhs.0)
});
}
impl BitOrAssign<I32VecAvx512> for I32VecAvx512 {
fn_avx!(this: &mut I32VecAvx512, fn bitor_assign(rhs: I32VecAvx512) {
this.0 = _mm512_or_si512(this.0, rhs.0)
});
}
impl BitXorAssign<I32VecAvx512> for I32VecAvx512 {
fn_avx!(this: &mut I32VecAvx512, fn bitxor_assign(rhs: I32VecAvx512) {
this.0 = _mm512_xor_si512(this.0, rhs.0)
});
}
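/// 16-lane u32 vector; shares the `__m512i` representation with `I32VecAvx512`.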
#[derive(Clone, Copy, Debug)]
pub struct U32VecAvx512(__m512i, Avx512Descriptor);
impl U32SimdVec for U32VecAvx512 {
type Descriptor = Avx512Descriptor;
const LEN: usize = 16;
#[inline(always)]
fn bitcast_to_i32(self) -> I32VecAvx512 {
I32VecAvx512(self.0, self.1)
}
#[inline(always)]
fn shr<const AMOUNT_U: u32, const AMOUNT_I: i32>(self) -> Self {
#[arcane]
#[inline(always)]
fn impl_<const AMOUNT_U: u32>(_: archmage::X64V4Token, v: __m512i) -> __m512i {
_mm512_srli_epi32::<AMOUNT_U>(v)
}
Self(impl_::<AMOUNT_U>(self.1.token(), self.0), self.1)
}
}
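/// 64-lane u8 vector backed by a zmm register.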
#[derive(Clone, Copy, Debug)]
pub struct U8VecAvx512(__m512i, Avx512Descriptor);
impl U8SimdVec for U8VecAvx512 {
type Descriptor = Avx512Descriptor;
const LEN: usize = 64;
#[inline(always)]
fn load(d: Self::Descriptor, mem: &[u8]) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, mem: &[u8]) -> __m512i {
_mm512_loadu_si512(mem.first_chunk::<64>().unwrap())
}
Self(impl_(d.token(), mem), d)
}
#[inline(always)]
fn splat(d: Self::Descriptor, v: u8) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: u8) -> __m512i {
_mm512_set1_epi8(v as i8)
}
Self(impl_(d.token(), v), d)
}
#[inline(always)]
fn store(&self, mem: &mut [u8]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: __m512i, mem: &mut [u8]) {
_mm512_storeu_si512(mem.first_chunk_mut::<64>().unwrap(), v)
}
impl_(self.1.token(), self.0, mem)
}
#[inline(always)]
fn store_interleaved_2(a: Self, b: Self, dest: &mut [u8]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, a: __m512i, b: __m512i, dest: &mut [u8]) {
assert!(dest.len() >= 2 * U8VecAvx512::LEN);
let lo = _mm512_unpacklo_epi8(a, b);
let hi = _mm512_unpackhi_epi8(a, b);
let idx0 = _mm512_setr_epi64(0, 1, 8, 9, 2, 3, 10, 11);
let idx1 = _mm512_setr_epi64(4, 5, 12, 13, 6, 7, 14, 15);
let out0 = _mm512_permutex2var_epi64(lo, idx0, hi);
let out1 = _mm512_permutex2var_epi64(lo, idx1, hi);
_mm512_storeu_si512(dest[..64].first_chunk_mut::<64>().unwrap(), out0);
_mm512_storeu_si512(dest[64..128].first_chunk_mut::<64>().unwrap(), out1);
}
impl_(a.1.token(), a.0, b.0, dest)
}
#[inline(always)]
fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [u8]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, a: __m512i, b: __m512i, c: __m512i, dest: &mut [u8]) {
assert!(dest.len() >= 3 * U8VecAvx512::LEN);
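// Byte shuffles build the `a b c a b c ...` pattern independently inside
// each 128-bit lane (a -1 entry zeroes that byte, so the three shuffled
// vectors can be OR-ed together); the epi64 permutes then reorder the
// 128-bit lanes into one sequential stream (trailing index entries are
// placeholders that the second permute overwrites).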
let mask_a0 = _mm512_broadcast_i32x4(_mm_setr_epi8(
0, -1, -1, 1, -1, -1, 2, -1, -1, 3, -1, -1, 4, -1, -1, 5,
));
let mask_b0 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, 0, -1, -1, 1, -1, -1, 2, -1, -1, 3, -1, -1, 4, -1, -1,
));
let mask_c0 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, -1, 0, -1, -1, 1, -1, -1, 2, -1, -1, 3, -1, -1, 4, -1,
));
let mask_a1 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, -1, 6, -1, -1, 7, -1, -1, 8, -1, -1, 9, -1, -1, 10, -1,
));
let mask_b1 = _mm512_broadcast_i32x4(_mm_setr_epi8(
5, -1, -1, 6, -1, -1, 7, -1, -1, 8, -1, -1, 9, -1, -1, 10,
));
let mask_c1 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, 5, -1, -1, 6, -1, -1, 7, -1, -1, 8, -1, -1, 9, -1, -1,
));
let mask_a2 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, 11, -1, -1, 12, -1, -1, 13, -1, -1, 14, -1, -1, 15, -1, -1,
));
let mask_b2 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, -1, 11, -1, -1, 12, -1, -1, 13, -1, -1, 14, -1, -1, 15, -1,
));
let mask_c2 = _mm512_broadcast_i32x4(_mm_setr_epi8(
10, -1, -1, 11, -1, -1, 12, -1, -1, 13, -1, -1, 14, -1, -1, 15,
));
let res0 = _mm512_or_si512(
_mm512_or_si512(
_mm512_shuffle_epi8(a, mask_a0),
_mm512_shuffle_epi8(b, mask_b0),
),
_mm512_shuffle_epi8(c, mask_c0),
);
let res1 = _mm512_or_si512(
_mm512_or_si512(
_mm512_shuffle_epi8(a, mask_a1),
_mm512_shuffle_epi8(b, mask_b1),
),
_mm512_shuffle_epi8(c, mask_c1),
);
let res2 = _mm512_or_si512(
_mm512_or_si512(
_mm512_shuffle_epi8(a, mask_a2),
_mm512_shuffle_epi8(b, mask_b2),
),
_mm512_shuffle_epi8(c, mask_c2),
);
let idx_a0 = _mm512_setr_epi64(0, 1, 8, 9, 2, 3, 0, 1);
let part_a0 = _mm512_permutex2var_epi64(res0, idx_a0, res1);
let idx_f0 = _mm512_setr_epi64(0, 1, 2, 3, 8, 9, 4, 5);
let final0 = _mm512_permutex2var_epi64(part_a0, idx_f0, res2);
let idx_a1 = _mm512_setr_epi64(2, 3, 10, 11, 4, 5, 0, 1);
let part_a1 = _mm512_permutex2var_epi64(res1, idx_a1, res2);
let idx_f1 = _mm512_setr_epi64(0, 1, 2, 3, 12, 13, 4, 5);
let final1 = _mm512_permutex2var_epi64(part_a1, idx_f1, res0);
let idx_a2 = _mm512_setr_epi64(4, 5, 14, 15, 6, 7, 0, 1);
let part_a2 = _mm512_permutex2var_epi64(res2, idx_a2, res0);
let idx_f2 = _mm512_setr_epi64(0, 1, 2, 3, 14, 15, 4, 5);
let final2 = _mm512_permutex2var_epi64(part_a2, idx_f2, res1);
_mm512_storeu_si512(dest[..64].first_chunk_mut::<64>().unwrap(), final0);
_mm512_storeu_si512(dest[64..128].first_chunk_mut::<64>().unwrap(), final1);
_mm512_storeu_si512(dest[128..192].first_chunk_mut::<64>().unwrap(), final2);
}
impl_(a.1.token(), a.0, b.0, c.0, dest)
}
#[inline(always)]
fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [u8]) {
#[arcane]
#[inline(always)]
fn impl_(
_: archmage::X64V4Token,
a: __m512i,
b: __m512i,
c: __m512i,
d: __m512i,
dest: &mut [u8],
) {
assert!(dest.len() >= 4 * U8VecAvx512::LEN);
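// Same structure as the f32 4-way interleave, one element size down: epi8
// and epi16 unpacks build abcd quads per 128-bit lane, then epi64 permutes
// order the lanes.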
let ab_lo = _mm512_unpacklo_epi8(a, b);
let ab_hi = _mm512_unpackhi_epi8(a, b);
let cd_lo = _mm512_unpacklo_epi8(c, d);
let cd_hi = _mm512_unpackhi_epi8(c, d);
let abcd_0 = _mm512_unpacklo_epi16(ab_lo, cd_lo);
let abcd_1 = _mm512_unpackhi_epi16(ab_lo, cd_lo);
let abcd_2 = _mm512_unpacklo_epi16(ab_hi, cd_hi);
let abcd_3 = _mm512_unpackhi_epi16(ab_hi, cd_hi);
let idx_even = _mm512_setr_epi64(0, 1, 8, 9, 2, 3, 10, 11);
let idx_odd = _mm512_setr_epi64(4, 5, 12, 13, 6, 7, 14, 15);
let pair01_02 = _mm512_permutex2var_epi64(abcd_0, idx_even, abcd_1);
let pair01_13 = _mm512_permutex2var_epi64(abcd_0, idx_odd, abcd_1);
let pair23_02 = _mm512_permutex2var_epi64(abcd_2, idx_even, abcd_3);
let pair23_13 = _mm512_permutex2var_epi64(abcd_2, idx_odd, abcd_3);
let idx_0 = _mm512_setr_epi64(0, 1, 2, 3, 8, 9, 10, 11);
let idx_1 = _mm512_setr_epi64(4, 5, 6, 7, 12, 13, 14, 15);
let out0 = _mm512_permutex2var_epi64(pair01_02, idx_0, pair23_02);
let out1 = _mm512_permutex2var_epi64(pair01_02, idx_1, pair23_02);
let out2 = _mm512_permutex2var_epi64(pair01_13, idx_0, pair23_13);
let out3 = _mm512_permutex2var_epi64(pair01_13, idx_1, pair23_13);
_mm512_storeu_si512(dest[..64].first_chunk_mut::<64>().unwrap(), out0);
_mm512_storeu_si512(dest[64..128].first_chunk_mut::<64>().unwrap(), out1);
_mm512_storeu_si512(dest[128..192].first_chunk_mut::<64>().unwrap(), out2);
_mm512_storeu_si512(dest[192..256].first_chunk_mut::<64>().unwrap(), out3);
}
impl_(a.1.token(), a.0, b.0, c.0, d.0, dest)
}
}
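/// 32-lane u16 vector backed by a zmm register.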
#[derive(Clone, Copy, Debug)]
pub struct U16VecAvx512(__m512i, Avx512Descriptor);
impl U16SimdVec for U16VecAvx512 {
type Descriptor = Avx512Descriptor;
const LEN: usize = 32;
#[inline(always)]
fn load(d: Self::Descriptor, mem: &[u16]) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, mem: &[u16]) -> __m512i {
_mm512_loadu_si512(mem.first_chunk::<32>().unwrap())
}
Self(impl_(d.token(), mem), d)
}
#[inline(always)]
fn splat(d: Self::Descriptor, v: u16) -> Self {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: u16) -> __m512i {
_mm512_set1_epi16(v as i16)
}
Self(impl_(d.token(), v), d)
}
#[inline(always)]
fn store(&self, mem: &mut [u16]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, v: __m512i, mem: &mut [u16]) {
_mm512_storeu_si512(mem.first_chunk_mut::<32>().unwrap(), v)
}
impl_(self.1.token(), self.0, mem)
}
#[inline(always)]
fn store_interleaved_2(a: Self, b: Self, dest: &mut [u16]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, a: __m512i, b: __m512i, dest: &mut [u16]) {
assert!(dest.len() >= 2 * U16VecAvx512::LEN);
let lo = _mm512_unpacklo_epi16(a, b);
let hi = _mm512_unpackhi_epi16(a, b);
let idx0 = _mm512_setr_epi64(0, 1, 8, 9, 2, 3, 10, 11);
let idx1 = _mm512_setr_epi64(4, 5, 12, 13, 6, 7, 14, 15);
let out0 = _mm512_permutex2var_epi64(lo, idx0, hi);
let out1 = _mm512_permutex2var_epi64(lo, idx1, hi);
_mm512_storeu_si512(dest[..32].first_chunk_mut::<32>().unwrap(), out0);
_mm512_storeu_si512(dest[32..64].first_chunk_mut::<32>().unwrap(), out1);
}
impl_(a.1.token(), a.0, b.0, dest)
}
#[inline(always)]
fn store_interleaved_3(a: Self, b: Self, c: Self, dest: &mut [u16]) {
#[arcane]
#[inline(always)]
fn impl_(_: archmage::X64V4Token, a: __m512i, b: __m512i, c: __m512i, dest: &mut [u16]) {
assert!(dest.len() >= 3 * U16VecAvx512::LEN);
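// Same scheme as the u8 store_interleaved_3, with 16-bit elements: per-lane
// byte shuffles build the `a b c ...` pattern, then epi64 permutes reorder
// the 128-bit lanes.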
let mask_a0 = _mm512_broadcast_i32x4(_mm_setr_epi8(
0, 1, -1, -1, -1, -1, 2, 3, -1, -1, -1, -1, 4, 5, -1, -1,
));
let mask_b0 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, -1, 0, 1, -1, -1, -1, -1, 2, 3, -1, -1, -1, -1, 4, 5,
));
let mask_c0 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, -1, -1, -1, 0, 1, -1, -1, -1, -1, 2, 3, -1, -1, -1, -1,
));
let mask_a1 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, -1, 6, 7, -1, -1, -1, -1, 8, 9, -1, -1, -1, -1, 10, 11,
));
let mask_b1 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, -1, -1, -1, 6, 7, -1, -1, -1, -1, 8, 9, -1, -1, -1, -1,
));
let mask_c1 = _mm512_broadcast_i32x4(_mm_setr_epi8(
4, 5, -1, -1, -1, -1, 6, 7, -1, -1, -1, -1, 8, 9, -1, -1,
));
let mask_a2 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, -1, -1, -1, 12, 13, -1, -1, -1, -1, 14, 15, -1, -1, -1, -1,
));
let mask_b2 = _mm512_broadcast_i32x4(_mm_setr_epi8(
10, 11, -1, -1, -1, -1, 12, 13, -1, -1, -1, -1, 14, 15, -1, -1,
));
let mask_c2 = _mm512_broadcast_i32x4(_mm_setr_epi8(
-1, -1, 10, 11, -1, -1, -1, -1, 12, 13, -1, -1, -1, -1, 14, 15,
));
let res0 = _mm512_or_si512(
_mm512_or_si512(
_mm512_shuffle_epi8(a, mask_a0),
_mm512_shuffle_epi8(b, mask_b0),
),
_mm512_shuffle_epi8(c, mask_c0),
);
let res1 = _mm512_or_si512(
_mm512_or_si512(
_mm512_shuffle_epi8(a, mask_a1),
_mm512_shuffle_epi8(b, mask_b1),
),
_mm512_shuffle_epi8(c, mask_c1),
);
let res2 = _mm512_or_si512(
_mm512_or_si512(
_mm512_shuffle_epi8(a, mask_a2),
_mm512_shuffle_epi8(b, mask_b2),
),
_mm512_shuffle_epi8(c, mask_c2),
);
let idx_a0 = _mm512_setr_epi64(0, 1, 8, 9, 2, 3, 0, 1);
let part_a0 = _mm512_permutex2var_epi64(res0, idx_a0, res1);
let idx_f0 = _mm512_setr_epi64(0, 1, 2, 3, 8, 9, 4, 5);
let final0 = _mm512_permutex2var_epi64(part_a0, idx_f0, res2);
let idx_a1 = _mm512_setr_epi64(2, 3, 10, 11, 4, 5, 0, 1);
let part_a1 = _mm512_permutex2var_epi64(res1, idx_a1, res2);
let idx_f1 = _mm512_setr_epi64(0, 1, 2, 3, 12, 13, 4, 5);
let final1 = _mm512_permutex2var_epi64(part_a1, idx_f1, res0);
let idx_a2 = _mm512_setr_epi64(4, 5, 14, 15, 6, 7, 0, 1);
let part_a2 = _mm512_permutex2var_epi64(res2, idx_a2, res0);
let idx_f2 = _mm512_setr_epi64(0, 1, 2, 3, 14, 15, 4, 5);
let final2 = _mm512_permutex2var_epi64(part_a2, idx_f2, res1);
_mm512_storeu_si512(dest[..32].first_chunk_mut::<32>().unwrap(), final0);
_mm512_storeu_si512(dest[32..64].first_chunk_mut::<32>().unwrap(), final1);
_mm512_storeu_si512(dest[64..96].first_chunk_mut::<32>().unwrap(), final2);
}
impl_(a.1.token(), a.0, b.0, c.0, dest)
}
#[inline(always)]
fn store_interleaved_4(a: Self, b: Self, c: Self, d: Self, dest: &mut [u16]) {
#[arcane]
#[inline(always)]
fn impl_(
_: archmage::X64V4Token,
a: __m512i,
b: __m512i,
c: __m512i,
d: __m512i,
dest: &mut [u16],
) {
assert!(dest.len() >= 4 * U16VecAvx512::LEN);
let ab_lo = _mm512_unpacklo_epi16(a, b);
let ab_hi = _mm512_unpackhi_epi16(a, b);
let cd_lo = _mm512_unpacklo_epi16(c, d);
let cd_hi = _mm512_unpackhi_epi16(c, d);
let abcd_0 = _mm512_unpacklo_epi32(ab_lo, cd_lo);
let abcd_1 = _mm512_unpackhi_epi32(ab_lo, cd_lo);
let abcd_2 = _mm512_unpacklo_epi32(ab_hi, cd_hi);
let abcd_3 = _mm512_unpackhi_epi32(ab_hi, cd_hi);
let idx_even = _mm512_setr_epi64(0, 1, 8, 9, 2, 3, 10, 11);
let idx_odd = _mm512_setr_epi64(4, 5, 12, 13, 6, 7, 14, 15);
let pair01_02 = _mm512_permutex2var_epi64(abcd_0, idx_even, abcd_1);
let pair01_13 = _mm512_permutex2var_epi64(abcd_0, idx_odd, abcd_1);
let pair23_02 = _mm512_permutex2var_epi64(abcd_2, idx_even, abcd_3);
let pair23_13 = _mm512_permutex2var_epi64(abcd_2, idx_odd, abcd_3);
let idx_0 = _mm512_setr_epi64(0, 1, 2, 3, 8, 9, 10, 11);
let idx_1 = _mm512_setr_epi64(4, 5, 6, 7, 12, 13, 14, 15);
let out0 = _mm512_permutex2var_epi64(pair01_02, idx_0, pair23_02);
let out1 = _mm512_permutex2var_epi64(pair01_02, idx_1, pair23_02);
let out2 = _mm512_permutex2var_epi64(pair01_13, idx_0, pair23_13);
let out3 = _mm512_permutex2var_epi64(pair01_13, idx_1, pair23_13);
_mm512_storeu_si512(dest[..32].first_chunk_mut::<32>().unwrap(), out0);
_mm512_storeu_si512(dest[32..64].first_chunk_mut::<32>().unwrap(), out1);
_mm512_storeu_si512(dest[64..96].first_chunk_mut::<32>().unwrap(), out2);
_mm512_storeu_si512(dest[96..128].first_chunk_mut::<32>().unwrap(), out3);
}
impl_(a.1.token(), a.0, b.0, c.0, d.0, dest)
}
}
impl SimdMask for MaskAvx512 {
type Descriptor = Avx512Descriptor;
fn_avx!(this: MaskAvx512, fn if_then_else_f32(if_true: F32VecAvx512, if_false: F32VecAvx512) -> F32VecAvx512 {
F32VecAvx512(_mm512_mask_blend_ps(this.0, if_false.0, if_true.0), this.1)
});
fn_avx!(this: MaskAvx512, fn if_then_else_i32(if_true: I32VecAvx512, if_false: I32VecAvx512) -> I32VecAvx512 {
I32VecAvx512(_mm512_mask_blend_epi32(this.0, if_false.0, if_true.0), this.1)
});
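// Zeroes the lanes selected by the mask and keeps the rest of `v`
// (`_mm512_mask_set1_epi32` broadcasts 0 into the k-selected lanes only).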
fn_avx!(this: MaskAvx512, fn maskz_i32(v: I32VecAvx512) -> I32VecAvx512 {
I32VecAvx512(_mm512_mask_set1_epi32(v.0, this.0, 0), this.1)
});
fn_avx!(this: MaskAvx512, fn all() -> bool {
this.0 == u16::MAX
});
fn_avx!(this: MaskAvx512, fn andnot(rhs: MaskAvx512) -> MaskAvx512 {
MaskAvx512((!this.0) & rhs.0, this.1)
});
}
impl BitAnd<MaskAvx512> for MaskAvx512 {
type Output = MaskAvx512;
fn_avx!(this: MaskAvx512, fn bitand(rhs: MaskAvx512) -> MaskAvx512 {
MaskAvx512(this.0 & rhs.0, this.1)
});
}
impl BitOr<MaskAvx512> for MaskAvx512 {
type Output = MaskAvx512;
fn_avx!(this: MaskAvx512, fn bitor(rhs: MaskAvx512) -> MaskAvx512 {
MaskAvx512(this.0 | rhs.0, this.1)
});
}