rten_simd/arch/x86_64/
avx512.rs

1use std::arch::x86_64::{
2    __m512, __m512i, __mmask16, __mmask32, __mmask64, _CMP_EQ_OQ, _CMP_GE_OQ, _CMP_GT_OQ,
3    _CMP_LE_OQ, _CMP_LT_OQ, _MM_CMPINT_EQ, _MM_CMPINT_NLE, _MM_CMPINT_NLT,
4    _MM_FROUND_TO_NEAREST_INT, _MM_HINT_ET0, _MM_HINT_T0, _mm_prefetch, _mm512_add_epi8,
5    _mm512_add_epi16, _mm512_add_epi32, _mm512_add_ps, _mm512_and_ps, _mm512_and_si512,
6    _mm512_andnot_ps, _mm512_andnot_si512, _mm512_castsi256_si512, _mm512_castsi512_si256,
7    _mm512_cmp_epi16_mask, _mm512_cmp_epi32_mask, _mm512_cmp_epu16_mask, _mm512_cmp_ps_mask,
8    _mm512_cmpeq_epi8_mask, _mm512_cmpeq_epu8_mask, _mm512_cmpge_epi8_mask, _mm512_cmpge_epu8_mask,
9    _mm512_cmpgt_epi8_mask, _mm512_cmpgt_epu8_mask, _mm512_cvtepi8_epi16, _mm512_cvtepi16_epi8,
10    _mm512_cvtepi16_epi32, _mm512_cvtepi32_ps, _mm512_cvtepu8_epi16, _mm512_cvtps_epi32,
11    _mm512_cvttps_epi32, _mm512_div_ps, _mm512_extracti64x4_epi64, _mm512_fmadd_ps,
12    _mm512_fnmadd_ps, _mm512_inserti64x4, _mm512_loadu_ps, _mm512_loadu_si512,
13    _mm512_mask_blend_epi8, _mm512_mask_blend_epi16, _mm512_mask_blend_epi32, _mm512_mask_blend_ps,
14    _mm512_mask_loadu_epi8, _mm512_mask_loadu_epi16, _mm512_mask_loadu_epi32, _mm512_mask_loadu_ps,
15    _mm512_mask_storeu_epi8, _mm512_mask_storeu_epi16, _mm512_mask_storeu_epi32,
16    _mm512_mask_storeu_ps, _mm512_max_ps, _mm512_min_ps, _mm512_mul_ps, _mm512_mullo_epi16,
17    _mm512_mullo_epi32, _mm512_or_ps, _mm512_or_si512, _mm512_packs_epi32, _mm512_packus_epi16,
18    _mm512_permutex2var_epi32, _mm512_permutexvar_epi64, _mm512_reduce_add_ps,
19    _mm512_roundscale_ps, _mm512_set1_epi8, _mm512_set1_epi16, _mm512_set1_epi32, _mm512_set1_ps,
20    _mm512_setr_epi32, _mm512_setr_epi64, _mm512_setzero_si512, _mm512_sllv_epi16,
21    _mm512_sllv_epi32, _mm512_srav_epi16, _mm512_srav_epi32, _mm512_srlv_epi16, _mm512_storeu_ps,
22    _mm512_storeu_si512, _mm512_sub_epi8, _mm512_sub_epi16, _mm512_sub_epi32, _mm512_sub_ps,
23    _mm512_unpackhi_epi8, _mm512_unpackhi_epi16, _mm512_unpacklo_epi8, _mm512_unpacklo_epi16,
24    _mm512_xor_ps, _mm512_xor_si512,
25};
26use std::mem::transmute;
27
28use super::super::{lanes, simd_type};
29use crate::ops::{
30    Concat, Extend, FloatOps, IntOps, Interleave, MaskOps, Narrow, NarrowSaturate, NumOps,
31    SignedIntOps, ToFloat,
32};
33use crate::{Isa, Mask, Simd};
34
35simd_type!(F32x16, __m512, f32, __mmask16, Avx512Isa);
36simd_type!(I32x16, __m512i, i32, __mmask16, Avx512Isa);
37simd_type!(I16x32, __m512i, i16, __mmask32, Avx512Isa);
38simd_type!(I8x64, __m512i, i8, __mmask64, Avx512Isa);
39simd_type!(U8x64, __m512i, u8, __mmask64, Avx512Isa);
40simd_type!(U16x32, __m512i, u16, __mmask32, Avx512Isa);
41simd_type!(U32x16, __m512i, u32, __mmask16, Avx512Isa);
42
/// Zero-sized token representing the AVX-512 instruction set.
///
/// The private unit field keeps code outside this module from building the
/// struct literal directly; instances are obtained via [`Avx512Isa::new`].
#[derive(Clone, Copy)]
pub struct Avx512Isa {
    _private: (),
}
47
48impl Avx512Isa {
49    pub fn new() -> Option<Self> {
50        if crate::is_avx512_supported() {
51            Some(Avx512Isa { _private: () })
52        } else {
53            None
54        }
55    }
56}
57
58// Safety: AVX-512 is supported as `Avx512Isa::new` checks this.
59unsafe impl Isa for Avx512Isa {
60    type M32 = __mmask16;
61    type M16 = __mmask32;
62    type M8 = __mmask64;
63    type F32 = F32x16;
64    type I32 = I32x16;
65    type I16 = I16x32;
66    type I8 = I8x64;
67    type U8 = U8x64;
68    type U16 = U16x32;
69    type U32 = U32x16;
70    type Bits = I32x16;
71
72    fn f32(self) -> impl FloatOps<f32, Simd = Self::F32, Int = Self::I32> {
73        self
74    }
75
76    fn i32(
77        self,
78    ) -> impl SignedIntOps<i32, Simd = Self::I32>
79    + NarrowSaturate<i32, i16, Output = Self::I16>
80    + Concat<i32>
81    + ToFloat<i32, Output = Self::F32> {
82        self
83    }
84
85    fn i16(
86        self,
87    ) -> impl SignedIntOps<i16, Simd = Self::I16>
88    + NarrowSaturate<i16, u8, Output = Self::U8>
89    + Extend<i16, Output = Self::I32>
90    + Interleave<i16> {
91        self
92    }
93
94    fn i8(
95        self,
96    ) -> impl SignedIntOps<i8, Simd = Self::I8> + Extend<i8, Output = Self::I16> + Interleave<i8>
97    {
98        self
99    }
100
101    fn u8(
102        self,
103    ) -> impl IntOps<u8, Simd = Self::U8> + Extend<u8, Output = Self::U16> + Interleave<u8> {
104        self
105    }
106
107    fn u16(self) -> impl IntOps<u16, Simd = Self::U16> {
108        self
109    }
110
111    fn m32(self) -> impl MaskOps<Self::M32> {
112        self
113    }
114
115    fn m16(self) -> impl MaskOps<Self::M16> {
116        self
117    }
118
119    fn m8(self) -> impl MaskOps<Self::M8> {
120        self
121    }
122}
123
/// Expands to the `NumOps` items that are written identically for every
/// vector type: the `Simd` associated type, the lane count, prefix-mask
/// construction and the prefetch hints.
///
/// `$simd` is the vector wrapper type and `$mask` the integer mask type that
/// goes with it.
macro_rules! simd_ops_common {
    ($simd:ty, $mask:ty) => {
        type Simd = $simd;

        #[inline]
        fn len(self) -> usize {
            lanes::<$simd>()
        }

        // Returns a mask with the lowest `n` bits set.
        //
        // The result saturates to "all bits set" once `n` reaches the width of
        // `$mask`. Computing it directly (instead of OR-ing one bit per
        // iteration) is O(1) and never evaluates a shift by more than the
        // mask width, which would panic when overflow checks are enabled.
        #[inline]
        fn first_n_mask(self, n: usize) -> $mask {
            if n >= <$mask>::BITS as usize {
                <$mask>::MAX
            } else {
                ((1 as $mask) << n) - 1
            }
        }

        // Hints that the cache line containing `ptr` will be read soon.
        #[inline]
        fn prefetch(self, ptr: *const <$simd as Simd>::Elem) {
            // SAFETY: prefetch is only a hint; it does not dereference `ptr`
            // architecturally and cannot fault.
            unsafe { _mm_prefetch::<_MM_HINT_T0>(ptr as *const i8) }
        }

        // Hints that the cache line containing `ptr` will be written soon.
        #[inline]
        fn prefetch_write(self, ptr: *mut <$simd as Simd>::Elem) {
            // SAFETY: as for `prefetch`.
            unsafe { _mm_prefetch::<_MM_HINT_ET0>(ptr as *const i8) }
        }
    };
}
153
/// Expands to the bitwise `NumOps` methods shared by all integer vector
/// types. The operations act on the whole 512-bit register, so one
/// implementation serves every lane width.
macro_rules! simd_int_ops_common {
    ($simd:ty) => {
        #[inline]
        fn and(self, x: $simd, y: $simd) -> $simd {
            let bits = unsafe { _mm512_and_si512(x.0, y.0) };
            bits.into()
        }

        #[inline]
        fn or(self, x: $simd, y: $simd) -> $simd {
            let bits = unsafe { _mm512_or_si512(x.0, y.0) };
            bits.into()
        }

        #[inline]
        fn xor(self, x: $simd, y: $simd) -> $simd {
            let bits = unsafe { _mm512_xor_si512(x.0, y.0) };
            bits.into()
        }

        #[inline]
        fn not(self, x: $simd) -> $simd {
            // XOR with an all-ones register flips every bit.
            let bits = unsafe {
                let all_ones = _mm512_set1_epi8(-1);
                _mm512_xor_si512(x.0, all_ones)
            };
            bits.into()
        }
    };
}
177
178unsafe impl NumOps<f32> for Avx512Isa {
179    simd_ops_common!(F32x16, __mmask16);
180
181    #[inline]
182    fn add(self, x: F32x16, y: F32x16) -> F32x16 {
183        unsafe { _mm512_add_ps(x.0, y.0) }.into()
184    }
185
186    #[inline]
187    fn sub(self, x: F32x16, y: F32x16) -> F32x16 {
188        unsafe { _mm512_sub_ps(x.0, y.0) }.into()
189    }
190
191    #[inline]
192    fn mul(self, x: F32x16, y: F32x16) -> F32x16 {
193        unsafe { _mm512_mul_ps(x.0, y.0) }.into()
194    }
195
196    #[inline]
197    fn mul_add(self, a: F32x16, b: F32x16, c: F32x16) -> F32x16 {
198        unsafe { _mm512_fmadd_ps(a.0, b.0, c.0) }.into()
199    }
200
201    #[inline]
202    fn lt(self, x: F32x16, y: F32x16) -> __mmask16 {
203        unsafe { _mm512_cmp_ps_mask(x.0, y.0, _CMP_LT_OQ) }
204    }
205
206    #[inline]
207    fn le(self, x: F32x16, y: F32x16) -> __mmask16 {
208        unsafe { _mm512_cmp_ps_mask(x.0, y.0, _CMP_LE_OQ) }
209    }
210
211    #[inline]
212    fn eq(self, x: F32x16, y: F32x16) -> __mmask16 {
213        unsafe { _mm512_cmp_ps_mask(x.0, y.0, _CMP_EQ_OQ) }
214    }
215
216    #[inline]
217    fn ge(self, x: F32x16, y: F32x16) -> __mmask16 {
218        unsafe { _mm512_cmp_ps_mask(x.0, y.0, _CMP_GE_OQ) }
219    }
220
221    #[inline]
222    fn gt(self, x: F32x16, y: F32x16) -> __mmask16 {
223        unsafe { _mm512_cmp_ps_mask(x.0, y.0, _CMP_GT_OQ) }
224    }
225
226    #[inline]
227    fn min(self, x: F32x16, y: F32x16) -> F32x16 {
228        unsafe { _mm512_min_ps(x.0, y.0) }.into()
229    }
230
231    #[inline]
232    fn max(self, x: F32x16, y: F32x16) -> F32x16 {
233        unsafe { _mm512_max_ps(x.0, y.0) }.into()
234    }
235
236    #[inline]
237    fn and(self, x: F32x16, y: F32x16) -> F32x16 {
238        unsafe { _mm512_and_ps(x.0, y.0) }.into()
239    }
240
241    #[inline]
242    fn not(self, x: F32x16) -> F32x16 {
243        let all_ones: F32x16 = self.splat(f32::from_bits(0xFFFFFFFF));
244        unsafe { _mm512_andnot_ps(x.0, all_ones.0) }.into()
245    }
246
247    #[inline]
248    fn or(self, x: F32x16, y: F32x16) -> F32x16 {
249        unsafe { _mm512_or_ps(x.0, y.0) }.into()
250    }
251
252    #[inline]
253    fn xor(self, x: F32x16, y: F32x16) -> F32x16 {
254        unsafe { _mm512_xor_ps(x.0, y.0) }.into()
255    }
256
257    #[inline]
258    fn splat(self, x: f32) -> F32x16 {
259        unsafe { _mm512_set1_ps(x) }.into()
260    }
261
262    #[inline]
263    unsafe fn load_ptr(self, ptr: *const f32) -> F32x16 {
264        unsafe { _mm512_loadu_ps(ptr) }.into()
265    }
266
267    #[inline]
268    fn select(self, x: F32x16, y: F32x16, mask: <F32x16 as Simd>::Mask) -> F32x16 {
269        unsafe { _mm512_mask_blend_ps(mask, y.0, x.0) }.into()
270    }
271
272    #[inline]
273    unsafe fn load_ptr_mask(self, ptr: *const f32, mask: __mmask16) -> F32x16 {
274        unsafe { _mm512_mask_loadu_ps(_mm512_set1_ps(0.), mask, ptr) }.into()
275    }
276
277    #[inline]
278    unsafe fn store_ptr_mask(self, x: F32x16, ptr: *mut f32, mask: __mmask16) {
279        unsafe { _mm512_mask_storeu_ps(ptr, mask, x.0) }
280    }
281
282    #[inline]
283    unsafe fn store_ptr(self, x: F32x16, ptr: *mut f32) {
284        unsafe { _mm512_storeu_ps(ptr, x.0) }
285    }
286
287    #[inline]
288    fn sum(self, x: F32x16) -> f32 {
289        unsafe { _mm512_reduce_add_ps(x.0) }
290    }
291}
292
293impl FloatOps<f32> for Avx512Isa {
294    type Int = <Self as Isa>::I32;
295
296    #[inline]
297    fn div(self, x: F32x16, y: F32x16) -> F32x16 {
298        unsafe { _mm512_div_ps(x.0, y.0) }.into()
299    }
300
301    #[inline]
302    fn abs(self, x: F32x16) -> F32x16 {
303        unsafe { _mm512_andnot_ps(_mm512_set1_ps(-0.0), x.0) }.into()
304    }
305
306    #[inline]
307    fn neg(self, x: F32x16) -> F32x16 {
308        unsafe { _mm512_xor_ps(x.0, _mm512_set1_ps(-0.0)) }.into()
309    }
310
311    #[inline]
312    fn mul_sub_from(self, a: F32x16, b: F32x16, c: F32x16) -> F32x16 {
313        unsafe { _mm512_fnmadd_ps(a.0, b.0, c.0) }.into()
314    }
315
316    #[inline]
317    fn round_ties_even(self, x: F32x16) -> F32x16 {
318        unsafe { _mm512_roundscale_ps(x.0, _MM_FROUND_TO_NEAREST_INT) }.into()
319    }
320
321    #[inline]
322    fn to_int_trunc(self, x: F32x16) -> Self::Int {
323        unsafe { _mm512_cvttps_epi32(x.0) }.into()
324    }
325
326    #[inline]
327    fn to_int_round(self, x: F32x16) -> Self::Int {
328        unsafe { _mm512_cvtps_epi32(x.0) }.into()
329    }
330}
331
332unsafe impl NumOps<i32> for Avx512Isa {
333    simd_ops_common!(I32x16, __mmask16);
334    simd_int_ops_common!(I32x16);
335
336    #[inline]
337    fn add(self, x: I32x16, y: I32x16) -> I32x16 {
338        unsafe { _mm512_add_epi32(x.0, y.0) }.into()
339    }
340
341    #[inline]
342    fn sub(self, x: I32x16, y: I32x16) -> I32x16 {
343        unsafe { _mm512_sub_epi32(x.0, y.0) }.into()
344    }
345
346    #[inline]
347    fn mul(self, x: I32x16, y: I32x16) -> I32x16 {
348        unsafe { _mm512_mullo_epi32(x.0, y.0) }.into()
349    }
350
351    #[inline]
352    fn splat(self, x: i32) -> I32x16 {
353        unsafe { _mm512_set1_epi32(x) }.into()
354    }
355
356    #[inline]
357    fn eq(self, x: I32x16, y: I32x16) -> __mmask16 {
358        unsafe { _mm512_cmp_epi32_mask(x.0, y.0, _MM_CMPINT_EQ) }
359    }
360
361    #[inline]
362    fn ge(self, x: I32x16, y: I32x16) -> __mmask16 {
363        unsafe { _mm512_cmp_epi32_mask(x.0, y.0, _MM_CMPINT_NLT) }
364    }
365
366    #[inline]
367    fn gt(self, x: I32x16, y: I32x16) -> __mmask16 {
368        unsafe { _mm512_cmp_epi32_mask(x.0, y.0, _MM_CMPINT_NLE) }
369    }
370
371    #[inline]
372    unsafe fn load_ptr(self, ptr: *const i32) -> I32x16 {
373        unsafe { _mm512_loadu_si512(ptr as *const __m512i) }.into()
374    }
375
376    #[inline]
377    fn select(self, x: I32x16, y: I32x16, mask: <I32x16 as Simd>::Mask) -> I32x16 {
378        unsafe { _mm512_mask_blend_epi32(mask, y.0, x.0) }.into()
379    }
380
381    #[inline]
382    unsafe fn store_ptr(self, x: I32x16, ptr: *mut i32) {
383        unsafe { _mm512_storeu_si512(ptr as *mut __m512i, x.0) }
384    }
385
386    #[inline]
387    unsafe fn load_ptr_mask(self, ptr: *const i32, mask: __mmask16) -> I32x16 {
388        unsafe { _mm512_mask_loadu_epi32(_mm512_set1_epi32(0), mask, ptr) }.into()
389    }
390
391    #[inline]
392    unsafe fn store_ptr_mask(self, x: I32x16, ptr: *mut i32, mask: __mmask16) {
393        unsafe { _mm512_mask_storeu_epi32(ptr, mask, x.0) }
394    }
395}
396
397impl IntOps<i32> for Avx512Isa {
398    #[inline]
399    fn shift_left<const SHIFT: i32>(self, x: I32x16) -> I32x16 {
400        let count: I32x16 = self.splat(SHIFT);
401        unsafe { _mm512_sllv_epi32(x.0, count.0) }.into()
402    }
403
404    #[inline]
405    fn shift_right<const SHIFT: i32>(self, x: I32x16) -> I32x16 {
406        let count: I32x16 = self.splat(SHIFT);
407        unsafe { _mm512_srav_epi32(x.0, count.0) }.into()
408    }
409}
410
411impl SignedIntOps<i32> for Avx512Isa {
412    #[inline]
413    fn neg(self, x: I32x16) -> I32x16 {
414        unsafe { _mm512_sub_epi32(_mm512_setzero_si512(), x.0) }.into()
415    }
416}
417
418impl NarrowSaturate<i32, i16> for Avx512Isa {
419    type Output = I16x32;
420
421    #[inline]
422    fn narrow_saturate(self, low: I32x16, high: I32x16) -> I16x32 {
423        unsafe {
424            // _mm512_packs_epi32 treats each input as 4 128-bit lanes and
425            // interleaves narrowed 64-bit blocks from each input. Shuffle the
426            // output to get narrowed lanes from `low` followed by lanes from
427            // `high`.
428            let packed = _mm512_packs_epi32(low.0, high.0);
429            let permutation = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
430            _mm512_permutexvar_epi64(permutation, packed)
431        }
432        .into()
433    }
434}
435
436impl Concat<i32> for Avx512Isa {
437    #[inline]
438    fn concat_low(self, a: I32x16, b: I32x16) -> I32x16 {
439        unsafe {
440            let a_lo = _mm512_castsi512_si256(a.0);
441            let b_lo = _mm512_castsi512_si256(b.0);
442            _mm512_inserti64x4(_mm512_castsi256_si512(a_lo), b_lo, 1)
443        }
444        .into()
445    }
446
447    #[inline]
448    fn concat_high(self, a: I32x16, b: I32x16) -> I32x16 {
449        unsafe {
450            let a_hi = _mm512_extracti64x4_epi64(a.0, 1);
451            let b_hi = _mm512_extracti64x4_epi64(b.0, 1);
452            _mm512_inserti64x4(_mm512_castsi256_si512(a_hi), b_hi, 1)
453        }
454        .into()
455    }
456}
457
458impl ToFloat<i32> for Avx512Isa {
459    type Output = F32x16;
460
461    #[inline]
462    fn to_float(self, x: I32x16) -> F32x16 {
463        unsafe { _mm512_cvtepi32_ps(x.0) }.into()
464    }
465}
466
467unsafe impl NumOps<i16> for Avx512Isa {
468    simd_ops_common!(I16x32, __mmask32);
469    simd_int_ops_common!(I16x32);
470
471    #[inline]
472    fn add(self, x: I16x32, y: I16x32) -> I16x32 {
473        unsafe { _mm512_add_epi16(x.0, y.0) }.into()
474    }
475
476    #[inline]
477    fn sub(self, x: I16x32, y: I16x32) -> I16x32 {
478        unsafe { _mm512_sub_epi16(x.0, y.0) }.into()
479    }
480
481    #[inline]
482    fn mul(self, x: I16x32, y: I16x32) -> I16x32 {
483        unsafe { _mm512_mullo_epi16(x.0, y.0) }.into()
484    }
485
486    #[inline]
487    fn splat(self, x: i16) -> I16x32 {
488        unsafe { _mm512_set1_epi16(x) }.into()
489    }
490
491    #[inline]
492    fn eq(self, x: I16x32, y: I16x32) -> __mmask32 {
493        unsafe { _mm512_cmp_epi16_mask(x.0, y.0, _MM_CMPINT_EQ) }
494    }
495
496    #[inline]
497    fn ge(self, x: I16x32, y: I16x32) -> __mmask32 {
498        unsafe { _mm512_cmp_epi16_mask(x.0, y.0, _MM_CMPINT_NLT) }
499    }
500
501    #[inline]
502    fn gt(self, x: I16x32, y: I16x32) -> __mmask32 {
503        unsafe { _mm512_cmp_epi16_mask(x.0, y.0, _MM_CMPINT_NLE) }
504    }
505
506    #[inline]
507    unsafe fn load_ptr(self, ptr: *const i16) -> I16x32 {
508        unsafe { _mm512_loadu_si512(ptr as *const __m512i) }.into()
509    }
510
511    #[inline]
512    fn select(self, x: I16x32, y: I16x32, mask: <I16x32 as Simd>::Mask) -> I16x32 {
513        unsafe { _mm512_mask_blend_epi16(mask, y.0, x.0) }.into()
514    }
515
516    #[inline]
517    unsafe fn store_ptr(self, x: I16x32, ptr: *mut i16) {
518        unsafe { _mm512_storeu_si512(ptr as *mut __m512i, x.0) }
519    }
520
521    #[inline]
522    unsafe fn load_ptr_mask(self, ptr: *const i16, mask: __mmask32) -> I16x32 {
523        unsafe { _mm512_mask_loadu_epi16(_mm512_set1_epi16(0), mask, ptr) }.into()
524    }
525
526    #[inline]
527    unsafe fn store_ptr_mask(self, x: I16x32, ptr: *mut i16, mask: __mmask32) {
528        unsafe { _mm512_mask_storeu_epi16(ptr, mask, x.0) }
529    }
530}
531
532impl IntOps<i16> for Avx512Isa {
533    #[inline]
534    fn shift_left<const SHIFT: i32>(self, x: I16x32) -> I16x32 {
535        let count: I16x32 = self.splat(SHIFT as i16);
536        unsafe { _mm512_sllv_epi16(x.0, count.0) }.into()
537    }
538
539    #[inline]
540    fn shift_right<const SHIFT: i32>(self, x: I16x32) -> I16x32 {
541        let count: I16x32 = self.splat(SHIFT as i16);
542        unsafe { _mm512_srav_epi16(x.0, count.0) }.into()
543    }
544}
545
546impl SignedIntOps<i16> for Avx512Isa {
547    #[inline]
548    fn neg(self, x: I16x32) -> I16x32 {
549        unsafe { _mm512_sub_epi16(_mm512_setzero_si512(), x.0) }.into()
550    }
551}
552
553impl NarrowSaturate<i16, u8> for Avx512Isa {
554    type Output = U8x64;
555
556    #[inline]
557    fn narrow_saturate(self, low: I16x32, high: I16x32) -> U8x64 {
558        unsafe {
559            // _mm512_packus_epi16 treats each input as 4 128-bit lanes and
560            // interleaves narrowed 64-bit blocks from each input. Shuffle the
561            // output to get narrowed lanes from `low` followed by lanes from
562            // `high`.
563            let packed = _mm512_packus_epi16(low.0, high.0);
564            let permutation = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
565            _mm512_permutexvar_epi64(permutation, packed)
566        }
567        .into()
568    }
569}
570
571impl Interleave<i16> for Avx512Isa {
572    #[inline]
573    fn interleave_low(self, a: I16x32, b: I16x32) -> I16x32 {
574        unsafe {
575            // AB{N} = Interleaved Nth 64-bit block.
576            let lo = _mm512_unpacklo_epi16(a.0, b.0); // AB0 AB2 AB4 AB6
577            let hi = _mm512_unpackhi_epi16(a.0, b.0); // AB1 AB3 AB5 AB7
578            let idx = _mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23);
579            _mm512_permutex2var_epi32(lo, idx, hi) // AB0 AB1 AB2 AB3
580        }
581        .into()
582    }
583
584    #[inline]
585    fn interleave_high(self, a: I16x32, b: I16x32) -> I16x32 {
586        unsafe {
587            // AB{N} = Interleaved Nth 64-bit block.
588            let lo = _mm512_unpacklo_epi16(a.0, b.0); // AB0 AB2 AB4 AB6
589            let hi = _mm512_unpackhi_epi16(a.0, b.0); // AB1 AB3 AB5 AB7
590            let idx =
591                _mm512_setr_epi32(8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31);
592            _mm512_permutex2var_epi32(lo, idx, hi) // AB4 AB5 AB6 AB7
593        }
594        .into()
595    }
596}
597
598unsafe impl NumOps<i8> for Avx512Isa {
599    simd_ops_common!(I8x64, __mmask64);
600    simd_int_ops_common!(I8x64);
601
602    #[inline]
603    fn add(self, x: I8x64, y: I8x64) -> I8x64 {
604        unsafe { _mm512_add_epi8(x.0, y.0) }.into()
605    }
606
607    #[inline]
608    fn sub(self, x: I8x64, y: I8x64) -> I8x64 {
609        unsafe { _mm512_sub_epi8(x.0, y.0) }.into()
610    }
611
612    #[inline]
613    fn mul(self, x: I8x64, y: I8x64) -> I8x64 {
614        let (x_lo, x_hi) = Extend::<i8>::extend(self, x);
615        let (y_lo, y_hi) = Extend::<i8>::extend(self, y);
616
617        let i16_ops = self.i16();
618        let prod_lo = i16_ops.mul(x_lo, y_lo);
619        let prod_hi = i16_ops.mul(x_hi, y_hi);
620
621        self.narrow_truncate(prod_lo, prod_hi)
622    }
623
624    #[inline]
625    fn splat(self, x: i8) -> I8x64 {
626        unsafe { _mm512_set1_epi8(x) }.into()
627    }
628
629    #[inline]
630    fn eq(self, x: I8x64, y: I8x64) -> __mmask64 {
631        unsafe { _mm512_cmpeq_epi8_mask(x.0, y.0) }
632    }
633
634    #[inline]
635    fn ge(self, x: I8x64, y: I8x64) -> __mmask64 {
636        unsafe { _mm512_cmpge_epi8_mask(x.0, y.0) }
637    }
638
639    #[inline]
640    fn gt(self, x: I8x64, y: I8x64) -> __mmask64 {
641        unsafe { _mm512_cmpgt_epi8_mask(x.0, y.0) }
642    }
643
644    #[inline]
645    unsafe fn load_ptr(self, ptr: *const i8) -> I8x64 {
646        unsafe { _mm512_loadu_si512(ptr as *const __m512i) }.into()
647    }
648
649    #[inline]
650    fn select(self, x: I8x64, y: I8x64, mask: <I8x64 as Simd>::Mask) -> I8x64 {
651        unsafe { _mm512_mask_blend_epi8(mask, y.0, x.0) }.into()
652    }
653
654    #[inline]
655    unsafe fn store_ptr(self, x: I8x64, ptr: *mut i8) {
656        unsafe { _mm512_storeu_si512(ptr as *mut __m512i, x.0) }
657    }
658
659    #[inline]
660    unsafe fn load_ptr_mask(self, ptr: *const i8, mask: __mmask64) -> I8x64 {
661        unsafe { _mm512_mask_loadu_epi8(_mm512_set1_epi8(0), mask, ptr) }.into()
662    }
663
664    #[inline]
665    unsafe fn store_ptr_mask(self, x: I8x64, ptr: *mut i8, mask: __mmask64) {
666        unsafe { _mm512_mask_storeu_epi8(ptr, mask, x.0) }
667    }
668}
669
670impl IntOps<i8> for Avx512Isa {
671    #[inline]
672    fn shift_left<const SHIFT: i32>(self, x: I8x64) -> I8x64 {
673        let (x_lo, x_hi) = Extend::<i8>::extend(self, x);
674
675        let i16_ops = self.i16();
676        let (y_lo, y_hi) = (
677            i16_ops.shift_left::<SHIFT>(x_lo),
678            i16_ops.shift_left::<SHIFT>(x_hi),
679        );
680
681        self.narrow_truncate(y_lo, y_hi)
682    }
683
684    #[inline]
685    fn shift_right<const SHIFT: i32>(self, x: I8x64) -> I8x64 {
686        let (x_lo, x_hi) = Extend::<i8>::extend(self, x);
687
688        let i16_ops = self.i16();
689        let (y_lo, y_hi) = (
690            i16_ops.shift_right::<SHIFT>(x_lo),
691            i16_ops.shift_right::<SHIFT>(x_hi),
692        );
693
694        self.narrow_truncate(y_lo, y_hi)
695    }
696}
697
698impl SignedIntOps<i8> for Avx512Isa {
699    #[inline]
700    fn neg(self, x: I8x64) -> I8x64 {
701        unsafe { _mm512_sub_epi8(_mm512_setzero_si512(), x.0) }.into()
702    }
703}
704
705#[inline]
706fn interleave_low_x8(a: __m512i, b: __m512i) -> __m512i {
707    unsafe {
708        // AB{N} = Interleaved Nth 64-bit block.
709        let lo = _mm512_unpacklo_epi8(a, b); // AB0 AB2 AB4 AB6
710        let hi = _mm512_unpackhi_epi8(a, b); // AB1 AB3 AB5 AB7
711        let idx = _mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23);
712        _mm512_permutex2var_epi32(lo, idx, hi) // AB0 AB1 AB2 AB3
713    }
714}
715
716#[inline]
717fn interleave_high_x8(a: __m512i, b: __m512i) -> __m512i {
718    unsafe {
719        // AB{N} = Interleaved Nth 64-bit block.
720        let lo = _mm512_unpacklo_epi8(a, b); // AB0 AB2 AB4 AB6
721        let hi = _mm512_unpackhi_epi8(a, b); // AB1 AB3 AB5 AB7
722        let idx = _mm512_setr_epi32(8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31);
723        _mm512_permutex2var_epi32(lo, idx, hi) // AB4 AB5 AB6 AB7
724    }
725}
726
727impl Interleave<i8> for Avx512Isa {
728    #[inline]
729    fn interleave_low(self, a: I8x64, b: I8x64) -> I8x64 {
730        interleave_low_x8(a.0, b.0).into()
731    }
732
733    #[inline]
734    fn interleave_high(self, a: I8x64, b: I8x64) -> I8x64 {
735        interleave_high_x8(a.0, b.0).into()
736    }
737}
738
739unsafe impl NumOps<u8> for Avx512Isa {
740    simd_ops_common!(U8x64, __mmask64);
741    simd_int_ops_common!(U8x64);
742
743    #[inline]
744    fn add(self, x: U8x64, y: U8x64) -> U8x64 {
745        unsafe { _mm512_add_epi8(x.0, y.0) }.into()
746    }
747
748    #[inline]
749    fn sub(self, x: U8x64, y: U8x64) -> U8x64 {
750        unsafe { _mm512_sub_epi8(x.0, y.0) }.into()
751    }
752
753    #[inline]
754    fn mul(self, x: U8x64, y: U8x64) -> U8x64 {
755        let (x_lo, x_hi) = Extend::<u8>::extend(self, x);
756        let (y_lo, y_hi) = Extend::<u8>::extend(self, y);
757
758        let u16_ops = self.u16();
759        let prod_lo = u16_ops.mul(x_lo, y_lo);
760        let prod_hi = u16_ops.mul(x_hi, y_hi);
761
762        self.narrow_truncate(prod_lo, prod_hi)
763    }
764
765    #[inline]
766    fn splat(self, x: u8) -> U8x64 {
767        unsafe { _mm512_set1_epi8(x as i8) }.into()
768    }
769
770    #[inline]
771    fn eq(self, x: U8x64, y: U8x64) -> __mmask64 {
772        unsafe { _mm512_cmpeq_epu8_mask(x.0, y.0) }
773    }
774
775    #[inline]
776    fn ge(self, x: U8x64, y: U8x64) -> __mmask64 {
777        unsafe { _mm512_cmpge_epu8_mask(x.0, y.0) }
778    }
779
780    #[inline]
781    fn gt(self, x: U8x64, y: U8x64) -> __mmask64 {
782        unsafe { _mm512_cmpgt_epu8_mask(x.0, y.0) }
783    }
784
785    #[inline]
786    unsafe fn load_ptr(self, ptr: *const u8) -> U8x64 {
787        unsafe { _mm512_loadu_si512(ptr as *const __m512i) }.into()
788    }
789
790    #[inline]
791    fn select(self, x: U8x64, y: U8x64, mask: <U8x64 as Simd>::Mask) -> U8x64 {
792        unsafe { _mm512_mask_blend_epi8(mask, y.0, x.0) }.into()
793    }
794
795    #[inline]
796    unsafe fn store_ptr(self, x: U8x64, ptr: *mut u8) {
797        unsafe { _mm512_storeu_si512(ptr as *mut __m512i, x.0) }
798    }
799
800    #[inline]
801    unsafe fn load_ptr_mask(self, ptr: *const u8, mask: __mmask64) -> U8x64 {
802        unsafe { _mm512_mask_loadu_epi8(_mm512_set1_epi8(0), mask, ptr as *const i8) }.into()
803    }
804
805    #[inline]
806    unsafe fn store_ptr_mask(self, x: U8x64, ptr: *mut u8, mask: __mmask64) {
807        unsafe { _mm512_mask_storeu_epi8(ptr as *mut i8, mask, x.0) }
808    }
809}
810
811impl Extend<i16> for Avx512Isa {
812    type Output = I32x16;
813
814    #[inline]
815    fn extend(self, x: I16x32) -> (Self::Output, Self::Output) {
816        unsafe {
817            let lo = _mm512_extracti64x4_epi64(x.0, 0);
818            let lo = _mm512_cvtepi16_epi32(lo);
819
820            let hi = _mm512_extracti64x4_epi64(x.0, 1);
821            let hi = _mm512_cvtepi16_epi32(hi);
822            (lo.into(), hi.into())
823        }
824    }
825}
826
827impl Extend<i8> for Avx512Isa {
828    type Output = I16x32;
829
830    #[inline]
831    fn extend(self, x: I8x64) -> (I16x32, I16x32) {
832        unsafe {
833            let lo = _mm512_extracti64x4_epi64(x.0, 0);
834            let lo = _mm512_cvtepi8_epi16(lo);
835
836            let hi = _mm512_extracti64x4_epi64(x.0, 1);
837            let hi = _mm512_cvtepi8_epi16(hi);
838            (lo.into(), hi.into())
839        }
840    }
841}
842
843impl Extend<u8> for Avx512Isa {
844    type Output = U16x32;
845
846    #[inline]
847    fn extend(self, x: U8x64) -> (U16x32, U16x32) {
848        unsafe {
849            let lo = _mm512_extracti64x4_epi64(x.0, 0);
850            let lo = _mm512_cvtepu8_epi16(lo);
851
852            let hi = _mm512_extracti64x4_epi64(x.0, 1);
853            let hi = _mm512_cvtepu8_epi16(hi);
854            (lo.into(), hi.into())
855        }
856    }
857}
858
859impl IntOps<u8> for Avx512Isa {
860    #[inline]
861    fn shift_left<const SHIFT: i32>(self, x: U8x64) -> U8x64 {
862        let (x_lo, x_hi) = Extend::<u8>::extend(self, x);
863
864        let u16_ops = self.u16();
865        let (y_lo, y_hi) = (
866            u16_ops.shift_left::<SHIFT>(x_lo),
867            u16_ops.shift_left::<SHIFT>(x_hi),
868        );
869
870        self.narrow_truncate(y_lo, y_hi)
871    }
872
873    #[inline]
874    fn shift_right<const SHIFT: i32>(self, x: U8x64) -> U8x64 {
875        let (x_lo, x_hi) = Extend::<u8>::extend(self, x);
876
877        let u16_ops = self.u16();
878        let (y_lo, y_hi) = (
879            u16_ops.shift_right::<SHIFT>(x_lo),
880            u16_ops.shift_right::<SHIFT>(x_hi),
881        );
882
883        self.narrow_truncate(y_lo, y_hi)
884    }
885}
886
887impl Interleave<u8> for Avx512Isa {
888    #[inline]
889    fn interleave_low(self, a: U8x64, b: U8x64) -> U8x64 {
890        unsafe {
891            // AB{N} = Interleaved Nth 64-bit block.
892            let lo = _mm512_unpacklo_epi8(a.0, b.0); // AB0 AB2 AB4 AB6
893            let hi = _mm512_unpackhi_epi8(a.0, b.0); // AB1 AB3 AB5 AB7
894            let idx = _mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23);
895            _mm512_permutex2var_epi32(lo, idx, hi) // AB0 AB1 AB2 AB3
896        }
897        .into()
898    }
899
900    #[inline]
901    fn interleave_high(self, a: U8x64, b: U8x64) -> U8x64 {
902        unsafe {
903            // AB{N} = Interleaved Nth 64-bit block.
904            let lo = _mm512_unpacklo_epi8(a.0, b.0); // AB0 AB2 AB4 AB6
905            let hi = _mm512_unpackhi_epi8(a.0, b.0); // AB1 AB3 AB5 AB7
906            let idx =
907                _mm512_setr_epi32(8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31);
908            _mm512_permutex2var_epi32(lo, idx, hi) // AB4 AB5 AB6 AB7
909        }
910        .into()
911    }
912}
913
914impl Narrow<I16x32> for Avx512Isa {
915    type Output = I8x64;
916
917    #[inline]
918    fn narrow_truncate(self, a: I16x32, b: I16x32) -> I8x64 {
919        let y = unsafe {
920            let lo_i8 = _mm512_cvtepi16_epi8(a.0);
921            let hi_i8 = _mm512_cvtepi16_epi8(b.0);
922            _mm512_inserti64x4(_mm512_castsi256_si512(lo_i8), hi_i8, 1)
923        };
924        I8x64(y)
925    }
926}
927
928impl Narrow<U16x32> for Avx512Isa {
929    type Output = U8x64;
930
931    #[inline]
932    fn narrow_truncate(self, a: U16x32, b: U16x32) -> U8x64 {
933        let y = unsafe {
934            let lo_u8 = _mm512_cvtepi16_epi8(a.0);
935            let hi_u8 = _mm512_cvtepi16_epi8(b.0);
936            _mm512_inserti64x4(_mm512_castsi256_si512(lo_u8), hi_u8, 1)
937        };
938        U8x64(y)
939    }
940}
941
942unsafe impl NumOps<u16> for Avx512Isa {
943    simd_ops_common!(U16x32, __mmask32);
944    simd_int_ops_common!(U16x32);
945
946    #[inline]
947    fn add(self, x: U16x32, y: U16x32) -> U16x32 {
948        unsafe { _mm512_add_epi16(x.0, y.0) }.into()
949    }
950
951    #[inline]
952    fn sub(self, x: U16x32, y: U16x32) -> U16x32 {
953        unsafe { _mm512_sub_epi16(x.0, y.0) }.into()
954    }
955
956    #[inline]
957    fn mul(self, x: U16x32, y: U16x32) -> U16x32 {
958        unsafe { _mm512_mullo_epi16(x.0, y.0) }.into()
959    }
960
961    #[inline]
962    fn splat(self, x: u16) -> U16x32 {
963        unsafe { _mm512_set1_epi16(x as i16) }.into()
964    }
965
966    #[inline]
967    fn eq(self, x: U16x32, y: U16x32) -> __mmask32 {
968        unsafe { _mm512_cmp_epu16_mask(x.0, y.0, _MM_CMPINT_EQ) }
969    }
970
971    #[inline]
972    fn ge(self, x: U16x32, y: U16x32) -> __mmask32 {
973        unsafe { _mm512_cmp_epu16_mask(x.0, y.0, _MM_CMPINT_NLT) }
974    }
975
976    #[inline]
977    fn gt(self, x: U16x32, y: U16x32) -> __mmask32 {
978        unsafe { _mm512_cmp_epu16_mask(x.0, y.0, _MM_CMPINT_NLE) }
979    }
980
981    #[inline]
982    unsafe fn load_ptr(self, ptr: *const u16) -> U16x32 {
983        unsafe { _mm512_loadu_si512(ptr as *const __m512i) }.into()
984    }
985
986    #[inline]
987    fn select(self, x: U16x32, y: U16x32, mask: <U16x32 as Simd>::Mask) -> U16x32 {
988        unsafe { _mm512_mask_blend_epi16(mask, y.0, x.0) }.into()
989    }
990
991    #[inline]
992    unsafe fn store_ptr(self, x: U16x32, ptr: *mut u16) {
993        unsafe { _mm512_storeu_si512(ptr as *mut __m512i, x.0) }
994    }
995
996    #[inline]
997    unsafe fn load_ptr_mask(self, ptr: *const u16, mask: __mmask32) -> U16x32 {
998        unsafe { _mm512_mask_loadu_epi16(_mm512_set1_epi16(0), mask, ptr as *const i16) }.into()
999    }
1000
1001    #[inline]
1002    unsafe fn store_ptr_mask(self, x: U16x32, ptr: *mut u16, mask: __mmask32) {
1003        unsafe { _mm512_mask_storeu_epi16(ptr as *mut i16, mask, x.0) }
1004    }
1005}
1006
1007impl IntOps<u16> for Avx512Isa {
1008    #[inline]
1009    fn shift_left<const SHIFT: i32>(self, x: U16x32) -> U16x32 {
1010        let count: I16x32 = self.splat(SHIFT as i16);
1011        unsafe { _mm512_sllv_epi16(x.0, count.0) }.into()
1012    }
1013
1014    #[inline]
1015    fn shift_right<const SHIFT: i32>(self, x: U16x32) -> U16x32 {
1016        let count: I16x32 = self.splat(SHIFT as i16);
1017        unsafe { _mm512_srlv_epi16(x.0, count.0) }.into()
1018    }
1019}
1020
/// Implements `Mask` and `MaskOps` for an integer mask type in which bit `i`
/// holds the state of lane `i`.
macro_rules! impl_mask {
    ($mask:ty) => {
        impl Mask for $mask {
            // One `bool` per bit of the mask integer.
            type Array = [bool; <$mask>::BITS as usize];

            #[inline]
            fn to_array(self) -> Self::Array {
                // Element `bit` is true when that bit of the mask is set.
                std::array::from_fn(|bit| (self >> bit) & 1 == 1)
            }
        }

        unsafe impl MaskOps<$mask> for Avx512Isa {
            #[inline]
            fn and(self, x: $mask, y: $mask) -> $mask {
                x & y
            }

            // True when at least one bit is set.
            #[inline]
            fn any(self, x: $mask) -> bool {
                x != 0
            }

            // True when every bit is set.
            #[inline]
            fn all(self, x: $mask) -> bool {
                x == <$mask>::MAX
            }
        }
    };
}
1050
// Implement `Mask` and `MaskOps` for each mask width that the vector types in
// this file are paired with.
impl_mask!(__mmask16);
impl_mask!(__mmask32);
impl_mask!(__mmask64);