1use std::arch::x86_64::{
2 __m512, __m512i, __mmask16, __mmask32, __mmask64, _CMP_EQ_OQ, _CMP_GE_OQ, _CMP_GT_OQ,
3 _CMP_LE_OQ, _CMP_LT_OQ, _MM_CMPINT_EQ, _MM_CMPINT_NLE, _MM_CMPINT_NLT,
4 _MM_FROUND_TO_NEAREST_INT, _MM_HINT_ET0, _MM_HINT_T0, _mm_prefetch, _mm512_add_epi8,
5 _mm512_add_epi16, _mm512_add_epi32, _mm512_add_ps, _mm512_and_ps, _mm512_and_si512,
6 _mm512_andnot_ps, _mm512_andnot_si512, _mm512_castsi256_si512, _mm512_castsi512_si256,
7 _mm512_cmp_epi16_mask, _mm512_cmp_epi32_mask, _mm512_cmp_epu16_mask, _mm512_cmp_ps_mask,
8 _mm512_cmpeq_epi8_mask, _mm512_cmpeq_epu8_mask, _mm512_cmpge_epi8_mask, _mm512_cmpge_epu8_mask,
9 _mm512_cmpgt_epi8_mask, _mm512_cmpgt_epu8_mask, _mm512_cvtepi8_epi16, _mm512_cvtepi16_epi8,
10 _mm512_cvtepi16_epi32, _mm512_cvtepi32_ps, _mm512_cvtepu8_epi16, _mm512_cvtps_epi32,
11 _mm512_cvttps_epi32, _mm512_div_ps, _mm512_extracti64x4_epi64, _mm512_fmadd_ps,
12 _mm512_fnmadd_ps, _mm512_inserti64x4, _mm512_loadu_ps, _mm512_loadu_si512,
13 _mm512_mask_blend_epi8, _mm512_mask_blend_epi16, _mm512_mask_blend_epi32, _mm512_mask_blend_ps,
14 _mm512_mask_loadu_epi8, _mm512_mask_loadu_epi16, _mm512_mask_loadu_epi32, _mm512_mask_loadu_ps,
15 _mm512_mask_storeu_epi8, _mm512_mask_storeu_epi16, _mm512_mask_storeu_epi32,
16 _mm512_mask_storeu_ps, _mm512_max_ps, _mm512_min_ps, _mm512_mul_ps, _mm512_mullo_epi16,
17 _mm512_mullo_epi32, _mm512_or_ps, _mm512_or_si512, _mm512_packs_epi32, _mm512_packus_epi16,
18 _mm512_permutex2var_epi32, _mm512_permutexvar_epi64, _mm512_reduce_add_ps,
19 _mm512_roundscale_ps, _mm512_set1_epi8, _mm512_set1_epi16, _mm512_set1_epi32, _mm512_set1_ps,
20 _mm512_setr_epi32, _mm512_setr_epi64, _mm512_setzero_si512, _mm512_sllv_epi16,
21 _mm512_sllv_epi32, _mm512_srav_epi16, _mm512_srav_epi32, _mm512_srlv_epi16, _mm512_storeu_ps,
22 _mm512_storeu_si512, _mm512_sub_epi8, _mm512_sub_epi16, _mm512_sub_epi32, _mm512_sub_ps,
23 _mm512_unpackhi_epi8, _mm512_unpackhi_epi16, _mm512_unpacklo_epi8, _mm512_unpacklo_epi16,
24 _mm512_xor_ps, _mm512_xor_si512,
25};
26use std::mem::transmute;
27
28use super::super::{lanes, simd_type};
29use crate::ops::{
30 Concat, Extend, FloatOps, IntOps, Interleave, MaskOps, Narrow, NarrowSaturate, NumOps,
31 SignedIntOps, ToFloat,
32};
33use crate::{Isa, Mask, Simd};
34
// Vector wrapper types for the AVX-512 backend, one per element type.
//
// `simd_type!` is defined elsewhere. Judging by the arguments, each invocation
// receives: the wrapper name, the underlying 512-bit intrinsic type, the lane
// element type, the mask type and the ISA type (TODO confirm against the macro
// definition). The mask width matches the lane count in the type name
// (16, 32 or 64).
simd_type!(F32x16, __m512, f32, __mmask16, Avx512Isa);
simd_type!(I32x16, __m512i, i32, __mmask16, Avx512Isa);
simd_type!(I16x32, __m512i, i16, __mmask32, Avx512Isa);
simd_type!(I8x64, __m512i, i8, __mmask64, Avx512Isa);
simd_type!(U8x64, __m512i, u8, __mmask64, Avx512Isa);
simd_type!(U16x32, __m512i, u16, __mmask32, Avx512Isa);
simd_type!(U32x16, __m512i, u32, __mmask16, Avx512Isa);
42
/// Zero-sized handle that gives access to the AVX-512 implementations of the
/// SIMD operation traits in this file.
///
/// The only constructor in this file is [`Avx512Isa::new`], which checks
/// `crate::is_avx512_supported()` first. The private unit field keeps code
/// outside this module from building a value with a struct literal.
#[derive(Copy, Clone)]
pub struct Avx512Isa {
    _private: (),
}
47
48impl Avx512Isa {
49 pub fn new() -> Option<Self> {
50 if crate::is_avx512_supported() {
51 Some(Avx512Isa { _private: () })
52 } else {
53 None
54 }
55 }
56}
57
// Binds the generic `Isa` interface to the 512-bit vector and mask types
// declared above. Every accessor returns `self`: `Avx512Isa` itself implements
// all of the per-element-type operation traits, so no separate "ops" objects
// are needed.
//
// NOTE(review): `Isa` is an `unsafe` trait whose contract is defined elsewhere.
// Presumably it requires that the instructions used by the operation traits
// are available whenever a value of this type exists — confirm against the
// trait documentation.
unsafe impl Isa for Avx512Isa {
    // Mask types for 32-bit, 16-bit and 8-bit lanes (one bit per lane).
    type M32 = __mmask16;
    type M16 = __mmask32;
    type M8 = __mmask64;
    // Vector types for each supported element type.
    type F32 = F32x16;
    type I32 = I32x16;
    type I16 = I16x32;
    type I8 = I8x64;
    type U8 = U8x64;
    type U16 = U16x32;
    type U32 = U32x16;
    type Bits = I32x16;

    /// Operations on `f32` vectors.
    fn f32(self) -> impl FloatOps<f32, Simd = Self::F32, Int = Self::I32> {
        self
    }

    /// Operations on `i32` vectors, including saturating narrowing to `i16`,
    /// half concatenation and conversion to `f32`.
    fn i32(
        self,
    ) -> impl SignedIntOps<i32, Simd = Self::I32>
           + NarrowSaturate<i32, i16, Output = Self::I16>
           + Concat<i32>
           + ToFloat<i32, Output = Self::F32> {
        self
    }

    /// Operations on `i16` vectors, including saturating narrowing to `u8`,
    /// widening to `i32` and interleaving.
    fn i16(
        self,
    ) -> impl SignedIntOps<i16, Simd = Self::I16>
           + NarrowSaturate<i16, u8, Output = Self::U8>
           + Extend<i16, Output = Self::I32>
           + Interleave<i16> {
        self
    }

    /// Operations on `i8` vectors, including widening to `i16` and
    /// interleaving.
    fn i8(
        self,
    ) -> impl SignedIntOps<i8, Simd = Self::I8> + Extend<i8, Output = Self::I16> + Interleave<i8>
    {
        self
    }

    /// Operations on `u8` vectors, including widening to `u16` and
    /// interleaving.
    fn u8(
        self,
    ) -> impl IntOps<u8, Simd = Self::U8> + Extend<u8, Output = Self::U16> + Interleave<u8> {
        self
    }

    /// Operations on `u16` vectors.
    fn u16(self) -> impl IntOps<u16, Simd = Self::U16> {
        self
    }

    /// Operations on masks for 32-bit lanes.
    fn m32(self) -> impl MaskOps<Self::M32> {
        self
    }

    /// Operations on masks for 16-bit lanes.
    fn m16(self) -> impl MaskOps<Self::M16> {
        self
    }

    /// Operations on masks for 8-bit lanes.
    fn m8(self) -> impl MaskOps<Self::M8> {
        self
    }
}
123
/// Generates the `NumOps` items that are identical for every vector type: the
/// `Simd` associated type, `len`, `first_n_mask` and the two prefetch hints.
///
/// * `$simd` - the vector wrapper type.
/// * `$mask` - the mask type paired with `$simd`; must be an unsigned integer
///   type (the `__mmaskN` aliases are), with one bit per lane.
macro_rules! simd_ops_common {
    ($simd:ty, $mask:ty) => {
        type Simd = $simd;

        /// Number of lanes in the vector, as reported by `lanes`.
        #[inline]
        fn len(self) -> usize {
            lanes::<$simd>()
        }

        /// Returns a mask whose lowest `n` bits are set.
        ///
        /// If `n` is greater than or equal to the bit width of the mask type,
        /// every bit is set.
        #[inline]
        fn first_n_mask(self, n: usize) -> $mask {
            // Closed form instead of setting one bit per loop iteration. The
            // width check keeps the shift amount in range: `1 << n` with `n`
            // equal to or above the bit width would overflow (a panic in debug
            // builds).
            let width = <$mask>::BITS as usize;
            if n >= width {
                <$mask>::MAX
            } else {
                ((1 as $mask) << n) - 1
            }
        }

        /// Issues a prefetch hint for the memory at `ptr` with the
        /// `_MM_HINT_T0` locality hint.
        #[inline]
        fn prefetch(self, ptr: *const <$simd as Simd>::Elem) {
            // SAFETY: a prefetch is only a hint to the CPU; per the Intel
            // documentation it does not fault on invalid addresses.
            unsafe { _mm_prefetch(ptr as *const i8, _MM_HINT_T0) }
        }

        /// Issues a prefetch hint for the memory at `ptr` with the
        /// `_MM_HINT_ET0` hint (T0 locality with intent to write, per the
        /// Intel intrinsics documentation).
        #[inline]
        fn prefetch_write(self, ptr: *mut <$simd as Simd>::Elem) {
            // SAFETY: see `prefetch`; the hint does not access memory
            // architecturally.
            unsafe { _mm_prefetch(ptr as *const i8, _MM_HINT_ET0) }
        }
    };
}
153
154macro_rules! simd_int_ops_common {
155 ($simd:ty) => {
156 #[inline]
157 fn and(self, x: $simd, y: $simd) -> $simd {
158 unsafe { _mm512_and_si512(x.0, y.0) }.into()
159 }
160
161 #[inline]
162 fn or(self, x: $simd, y: $simd) -> $simd {
163 unsafe { _mm512_or_si512(x.0, y.0) }.into()
164 }
165
166 #[inline]
167 fn xor(self, x: $simd, y: $simd) -> $simd {
168 unsafe { _mm512_xor_si512(x.0, y.0) }.into()
169 }
170
171 #[inline]
172 fn not(self, x: $simd) -> $simd {
173 unsafe { _mm512_andnot_si512(x.0, _mm512_set1_epi8(-1)) }.into()
174 }
175 };
176}
177
178unsafe impl NumOps<f32> for Avx512Isa {
179 simd_ops_common!(F32x16, __mmask16);
180
181 #[inline]
182 fn add(self, x: F32x16, y: F32x16) -> F32x16 {
183 unsafe { _mm512_add_ps(x.0, y.0) }.into()
184 }
185
186 #[inline]
187 fn sub(self, x: F32x16, y: F32x16) -> F32x16 {
188 unsafe { _mm512_sub_ps(x.0, y.0) }.into()
189 }
190
191 #[inline]
192 fn mul(self, x: F32x16, y: F32x16) -> F32x16 {
193 unsafe { _mm512_mul_ps(x.0, y.0) }.into()
194 }
195
196 #[inline]
197 fn mul_add(self, a: F32x16, b: F32x16, c: F32x16) -> F32x16 {
198 unsafe { _mm512_fmadd_ps(a.0, b.0, c.0) }.into()
199 }
200
201 #[inline]
202 fn lt(self, x: F32x16, y: F32x16) -> __mmask16 {
203 unsafe { _mm512_cmp_ps_mask(x.0, y.0, _CMP_LT_OQ) }
204 }
205
206 #[inline]
207 fn le(self, x: F32x16, y: F32x16) -> __mmask16 {
208 unsafe { _mm512_cmp_ps_mask(x.0, y.0, _CMP_LE_OQ) }
209 }
210
211 #[inline]
212 fn eq(self, x: F32x16, y: F32x16) -> __mmask16 {
213 unsafe { _mm512_cmp_ps_mask(x.0, y.0, _CMP_EQ_OQ) }
214 }
215
216 #[inline]
217 fn ge(self, x: F32x16, y: F32x16) -> __mmask16 {
218 unsafe { _mm512_cmp_ps_mask(x.0, y.0, _CMP_GE_OQ) }
219 }
220
221 #[inline]
222 fn gt(self, x: F32x16, y: F32x16) -> __mmask16 {
223 unsafe { _mm512_cmp_ps_mask(x.0, y.0, _CMP_GT_OQ) }
224 }
225
226 #[inline]
227 fn min(self, x: F32x16, y: F32x16) -> F32x16 {
228 unsafe { _mm512_min_ps(x.0, y.0) }.into()
229 }
230
231 #[inline]
232 fn max(self, x: F32x16, y: F32x16) -> F32x16 {
233 unsafe { _mm512_max_ps(x.0, y.0) }.into()
234 }
235
236 #[inline]
237 fn and(self, x: F32x16, y: F32x16) -> F32x16 {
238 unsafe { _mm512_and_ps(x.0, y.0) }.into()
239 }
240
241 #[inline]
242 fn not(self, x: F32x16) -> F32x16 {
243 let all_ones: F32x16 = self.splat(f32::from_bits(0xFFFFFFFF));
244 unsafe { _mm512_andnot_ps(x.0, all_ones.0) }.into()
245 }
246
247 #[inline]
248 fn or(self, x: F32x16, y: F32x16) -> F32x16 {
249 unsafe { _mm512_or_ps(x.0, y.0) }.into()
250 }
251
252 #[inline]
253 fn xor(self, x: F32x16, y: F32x16) -> F32x16 {
254 unsafe { _mm512_xor_ps(x.0, y.0) }.into()
255 }
256
257 #[inline]
258 fn splat(self, x: f32) -> F32x16 {
259 unsafe { _mm512_set1_ps(x) }.into()
260 }
261
262 #[inline]
263 unsafe fn load_ptr(self, ptr: *const f32) -> F32x16 {
264 unsafe { _mm512_loadu_ps(ptr) }.into()
265 }
266
267 #[inline]
268 fn select(self, x: F32x16, y: F32x16, mask: <F32x16 as Simd>::Mask) -> F32x16 {
269 unsafe { _mm512_mask_blend_ps(mask, y.0, x.0) }.into()
270 }
271
272 #[inline]
273 unsafe fn load_ptr_mask(self, ptr: *const f32, mask: __mmask16) -> F32x16 {
274 unsafe { _mm512_mask_loadu_ps(_mm512_set1_ps(0.), mask, ptr) }.into()
275 }
276
277 #[inline]
278 unsafe fn store_ptr_mask(self, x: F32x16, ptr: *mut f32, mask: __mmask16) {
279 unsafe { _mm512_mask_storeu_ps(ptr, mask, x.0) }
280 }
281
282 #[inline]
283 unsafe fn store_ptr(self, x: F32x16, ptr: *mut f32) {
284 unsafe { _mm512_storeu_ps(ptr, x.0) }
285 }
286
287 #[inline]
288 fn sum(self, x: F32x16) -> f32 {
289 unsafe { _mm512_reduce_add_ps(x.0) }
290 }
291}
292
293impl FloatOps<f32> for Avx512Isa {
294 type Int = <Self as Isa>::I32;
295
296 #[inline]
297 fn div(self, x: F32x16, y: F32x16) -> F32x16 {
298 unsafe { _mm512_div_ps(x.0, y.0) }.into()
299 }
300
301 #[inline]
302 fn abs(self, x: F32x16) -> F32x16 {
303 unsafe { _mm512_andnot_ps(_mm512_set1_ps(-0.0), x.0) }.into()
304 }
305
306 #[inline]
307 fn neg(self, x: F32x16) -> F32x16 {
308 unsafe { _mm512_xor_ps(x.0, _mm512_set1_ps(-0.0)) }.into()
309 }
310
311 #[inline]
312 fn mul_sub_from(self, a: F32x16, b: F32x16, c: F32x16) -> F32x16 {
313 unsafe { _mm512_fnmadd_ps(a.0, b.0, c.0) }.into()
314 }
315
316 #[inline]
317 fn round_ties_even(self, x: F32x16) -> F32x16 {
318 unsafe { _mm512_roundscale_ps(x.0, _MM_FROUND_TO_NEAREST_INT) }.into()
319 }
320
321 #[inline]
322 fn to_int_trunc(self, x: F32x16) -> Self::Int {
323 unsafe { _mm512_cvttps_epi32(x.0) }.into()
324 }
325
326 #[inline]
327 fn to_int_round(self, x: F32x16) -> Self::Int {
328 unsafe { _mm512_cvtps_epi32(x.0) }.into()
329 }
330}
331
332unsafe impl NumOps<i32> for Avx512Isa {
333 simd_ops_common!(I32x16, __mmask16);
334 simd_int_ops_common!(I32x16);
335
336 #[inline]
337 fn add(self, x: I32x16, y: I32x16) -> I32x16 {
338 unsafe { _mm512_add_epi32(x.0, y.0) }.into()
339 }
340
341 #[inline]
342 fn sub(self, x: I32x16, y: I32x16) -> I32x16 {
343 unsafe { _mm512_sub_epi32(x.0, y.0) }.into()
344 }
345
346 #[inline]
347 fn mul(self, x: I32x16, y: I32x16) -> I32x16 {
348 unsafe { _mm512_mullo_epi32(x.0, y.0) }.into()
349 }
350
351 #[inline]
352 fn splat(self, x: i32) -> I32x16 {
353 unsafe { _mm512_set1_epi32(x) }.into()
354 }
355
356 #[inline]
357 fn eq(self, x: I32x16, y: I32x16) -> __mmask16 {
358 unsafe { _mm512_cmp_epi32_mask(x.0, y.0, _MM_CMPINT_EQ) }
359 }
360
361 #[inline]
362 fn ge(self, x: I32x16, y: I32x16) -> __mmask16 {
363 unsafe { _mm512_cmp_epi32_mask(x.0, y.0, _MM_CMPINT_NLT) }
364 }
365
366 #[inline]
367 fn gt(self, x: I32x16, y: I32x16) -> __mmask16 {
368 unsafe { _mm512_cmp_epi32_mask(x.0, y.0, _MM_CMPINT_NLE) }
369 }
370
371 #[inline]
372 unsafe fn load_ptr(self, ptr: *const i32) -> I32x16 {
373 unsafe { _mm512_loadu_si512(ptr as *const __m512i) }.into()
374 }
375
376 #[inline]
377 fn select(self, x: I32x16, y: I32x16, mask: <I32x16 as Simd>::Mask) -> I32x16 {
378 unsafe { _mm512_mask_blend_epi32(mask, y.0, x.0) }.into()
379 }
380
381 #[inline]
382 unsafe fn store_ptr(self, x: I32x16, ptr: *mut i32) {
383 unsafe { _mm512_storeu_si512(ptr as *mut __m512i, x.0) }
384 }
385
386 #[inline]
387 unsafe fn load_ptr_mask(self, ptr: *const i32, mask: __mmask16) -> I32x16 {
388 unsafe { _mm512_mask_loadu_epi32(_mm512_set1_epi32(0), mask, ptr) }.into()
389 }
390
391 #[inline]
392 unsafe fn store_ptr_mask(self, x: I32x16, ptr: *mut i32, mask: __mmask16) {
393 unsafe { _mm512_mask_storeu_epi32(ptr, mask, x.0) }
394 }
395}
396
397impl IntOps<i32> for Avx512Isa {
398 #[inline]
399 fn shift_left<const SHIFT: i32>(self, x: I32x16) -> I32x16 {
400 let count: I32x16 = self.splat(SHIFT);
401 unsafe { _mm512_sllv_epi32(x.0, count.0) }.into()
402 }
403
404 #[inline]
405 fn shift_right<const SHIFT: i32>(self, x: I32x16) -> I32x16 {
406 let count: I32x16 = self.splat(SHIFT);
407 unsafe { _mm512_srav_epi32(x.0, count.0) }.into()
408 }
409}
410
411impl SignedIntOps<i32> for Avx512Isa {
412 #[inline]
413 fn neg(self, x: I32x16) -> I32x16 {
414 unsafe { _mm512_sub_epi32(_mm512_setzero_si512(), x.0) }.into()
415 }
416}
417
418impl NarrowSaturate<i32, i16> for Avx512Isa {
419 type Output = I16x32;
420
421 #[inline]
422 fn narrow_saturate(self, low: I32x16, high: I32x16) -> I16x32 {
423 unsafe {
424 let packed = _mm512_packs_epi32(low.0, high.0);
429 let permutation = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
430 _mm512_permutexvar_epi64(permutation, packed)
431 }
432 .into()
433 }
434}
435
436impl Concat<i32> for Avx512Isa {
437 #[inline]
438 fn concat_low(self, a: I32x16, b: I32x16) -> I32x16 {
439 unsafe {
440 let a_lo = _mm512_castsi512_si256(a.0);
441 let b_lo = _mm512_castsi512_si256(b.0);
442 _mm512_inserti64x4(_mm512_castsi256_si512(a_lo), b_lo, 1)
443 }
444 .into()
445 }
446
447 #[inline]
448 fn concat_high(self, a: I32x16, b: I32x16) -> I32x16 {
449 unsafe {
450 let a_hi = _mm512_extracti64x4_epi64(a.0, 1);
451 let b_hi = _mm512_extracti64x4_epi64(b.0, 1);
452 _mm512_inserti64x4(_mm512_castsi256_si512(a_hi), b_hi, 1)
453 }
454 .into()
455 }
456}
457
458impl ToFloat<i32> for Avx512Isa {
459 type Output = F32x16;
460
461 #[inline]
462 fn to_float(self, x: I32x16) -> F32x16 {
463 unsafe { _mm512_cvtepi32_ps(x.0) }.into()
464 }
465}
466
467unsafe impl NumOps<i16> for Avx512Isa {
468 simd_ops_common!(I16x32, __mmask32);
469 simd_int_ops_common!(I16x32);
470
471 #[inline]
472 fn add(self, x: I16x32, y: I16x32) -> I16x32 {
473 unsafe { _mm512_add_epi16(x.0, y.0) }.into()
474 }
475
476 #[inline]
477 fn sub(self, x: I16x32, y: I16x32) -> I16x32 {
478 unsafe { _mm512_sub_epi16(x.0, y.0) }.into()
479 }
480
481 #[inline]
482 fn mul(self, x: I16x32, y: I16x32) -> I16x32 {
483 unsafe { _mm512_mullo_epi16(x.0, y.0) }.into()
484 }
485
486 #[inline]
487 fn splat(self, x: i16) -> I16x32 {
488 unsafe { _mm512_set1_epi16(x) }.into()
489 }
490
491 #[inline]
492 fn eq(self, x: I16x32, y: I16x32) -> __mmask32 {
493 unsafe { _mm512_cmp_epi16_mask(x.0, y.0, _MM_CMPINT_EQ) }
494 }
495
496 #[inline]
497 fn ge(self, x: I16x32, y: I16x32) -> __mmask32 {
498 unsafe { _mm512_cmp_epi16_mask(x.0, y.0, _MM_CMPINT_NLT) }
499 }
500
501 #[inline]
502 fn gt(self, x: I16x32, y: I16x32) -> __mmask32 {
503 unsafe { _mm512_cmp_epi16_mask(x.0, y.0, _MM_CMPINT_NLE) }
504 }
505
506 #[inline]
507 unsafe fn load_ptr(self, ptr: *const i16) -> I16x32 {
508 unsafe { _mm512_loadu_si512(ptr as *const __m512i) }.into()
509 }
510
511 #[inline]
512 fn select(self, x: I16x32, y: I16x32, mask: <I16x32 as Simd>::Mask) -> I16x32 {
513 unsafe { _mm512_mask_blend_epi16(mask, y.0, x.0) }.into()
514 }
515
516 #[inline]
517 unsafe fn store_ptr(self, x: I16x32, ptr: *mut i16) {
518 unsafe { _mm512_storeu_si512(ptr as *mut __m512i, x.0) }
519 }
520
521 #[inline]
522 unsafe fn load_ptr_mask(self, ptr: *const i16, mask: __mmask32) -> I16x32 {
523 unsafe { _mm512_mask_loadu_epi16(_mm512_set1_epi16(0), mask, ptr) }.into()
524 }
525
526 #[inline]
527 unsafe fn store_ptr_mask(self, x: I16x32, ptr: *mut i16, mask: __mmask32) {
528 unsafe { _mm512_mask_storeu_epi16(ptr, mask, x.0) }
529 }
530}
531
532impl IntOps<i16> for Avx512Isa {
533 #[inline]
534 fn shift_left<const SHIFT: i32>(self, x: I16x32) -> I16x32 {
535 let count: I16x32 = self.splat(SHIFT as i16);
536 unsafe { _mm512_sllv_epi16(x.0, count.0) }.into()
537 }
538
539 #[inline]
540 fn shift_right<const SHIFT: i32>(self, x: I16x32) -> I16x32 {
541 let count: I16x32 = self.splat(SHIFT as i16);
542 unsafe { _mm512_srav_epi16(x.0, count.0) }.into()
543 }
544}
545
546impl SignedIntOps<i16> for Avx512Isa {
547 #[inline]
548 fn neg(self, x: I16x32) -> I16x32 {
549 unsafe { _mm512_sub_epi16(_mm512_setzero_si512(), x.0) }.into()
550 }
551}
552
553impl NarrowSaturate<i16, u8> for Avx512Isa {
554 type Output = U8x64;
555
556 #[inline]
557 fn narrow_saturate(self, low: I16x32, high: I16x32) -> U8x64 {
558 unsafe {
559 let packed = _mm512_packus_epi16(low.0, high.0);
564 let permutation = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7);
565 _mm512_permutexvar_epi64(permutation, packed)
566 }
567 .into()
568 }
569}
570
571impl Interleave<i16> for Avx512Isa {
572 #[inline]
573 fn interleave_low(self, a: I16x32, b: I16x32) -> I16x32 {
574 unsafe {
575 let lo = _mm512_unpacklo_epi16(a.0, b.0); let hi = _mm512_unpackhi_epi16(a.0, b.0); let idx = _mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23);
579 _mm512_permutex2var_epi32(lo, idx, hi) }
581 .into()
582 }
583
584 #[inline]
585 fn interleave_high(self, a: I16x32, b: I16x32) -> I16x32 {
586 unsafe {
587 let lo = _mm512_unpacklo_epi16(a.0, b.0); let hi = _mm512_unpackhi_epi16(a.0, b.0); let idx =
591 _mm512_setr_epi32(8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31);
592 _mm512_permutex2var_epi32(lo, idx, hi) }
594 .into()
595 }
596}
597
598unsafe impl NumOps<i8> for Avx512Isa {
599 simd_ops_common!(I8x64, __mmask64);
600 simd_int_ops_common!(I8x64);
601
602 #[inline]
603 fn add(self, x: I8x64, y: I8x64) -> I8x64 {
604 unsafe { _mm512_add_epi8(x.0, y.0) }.into()
605 }
606
607 #[inline]
608 fn sub(self, x: I8x64, y: I8x64) -> I8x64 {
609 unsafe { _mm512_sub_epi8(x.0, y.0) }.into()
610 }
611
612 #[inline]
613 fn mul(self, x: I8x64, y: I8x64) -> I8x64 {
614 let (x_lo, x_hi) = Extend::<i8>::extend(self, x);
615 let (y_lo, y_hi) = Extend::<i8>::extend(self, y);
616
617 let i16_ops = self.i16();
618 let prod_lo = i16_ops.mul(x_lo, y_lo);
619 let prod_hi = i16_ops.mul(x_hi, y_hi);
620
621 self.narrow_truncate(prod_lo, prod_hi)
622 }
623
624 #[inline]
625 fn splat(self, x: i8) -> I8x64 {
626 unsafe { _mm512_set1_epi8(x) }.into()
627 }
628
629 #[inline]
630 fn eq(self, x: I8x64, y: I8x64) -> __mmask64 {
631 unsafe { _mm512_cmpeq_epi8_mask(x.0, y.0) }
632 }
633
634 #[inline]
635 fn ge(self, x: I8x64, y: I8x64) -> __mmask64 {
636 unsafe { _mm512_cmpge_epi8_mask(x.0, y.0) }
637 }
638
639 #[inline]
640 fn gt(self, x: I8x64, y: I8x64) -> __mmask64 {
641 unsafe { _mm512_cmpgt_epi8_mask(x.0, y.0) }
642 }
643
644 #[inline]
645 unsafe fn load_ptr(self, ptr: *const i8) -> I8x64 {
646 unsafe { _mm512_loadu_si512(ptr as *const __m512i) }.into()
647 }
648
649 #[inline]
650 fn select(self, x: I8x64, y: I8x64, mask: <I8x64 as Simd>::Mask) -> I8x64 {
651 unsafe { _mm512_mask_blend_epi8(mask, y.0, x.0) }.into()
652 }
653
654 #[inline]
655 unsafe fn store_ptr(self, x: I8x64, ptr: *mut i8) {
656 unsafe { _mm512_storeu_si512(ptr as *mut __m512i, x.0) }
657 }
658
659 #[inline]
660 unsafe fn load_ptr_mask(self, ptr: *const i8, mask: __mmask64) -> I8x64 {
661 unsafe { _mm512_mask_loadu_epi8(_mm512_set1_epi8(0), mask, ptr) }.into()
662 }
663
664 #[inline]
665 unsafe fn store_ptr_mask(self, x: I8x64, ptr: *mut i8, mask: __mmask64) {
666 unsafe { _mm512_mask_storeu_epi8(ptr, mask, x.0) }
667 }
668}
669
670impl IntOps<i8> for Avx512Isa {
671 #[inline]
672 fn shift_left<const SHIFT: i32>(self, x: I8x64) -> I8x64 {
673 let (x_lo, x_hi) = Extend::<i8>::extend(self, x);
674
675 let i16_ops = self.i16();
676 let (y_lo, y_hi) = (
677 i16_ops.shift_left::<SHIFT>(x_lo),
678 i16_ops.shift_left::<SHIFT>(x_hi),
679 );
680
681 self.narrow_truncate(y_lo, y_hi)
682 }
683
684 #[inline]
685 fn shift_right<const SHIFT: i32>(self, x: I8x64) -> I8x64 {
686 let (x_lo, x_hi) = Extend::<i8>::extend(self, x);
687
688 let i16_ops = self.i16();
689 let (y_lo, y_hi) = (
690 i16_ops.shift_right::<SHIFT>(x_lo),
691 i16_ops.shift_right::<SHIFT>(x_hi),
692 );
693
694 self.narrow_truncate(y_lo, y_hi)
695 }
696}
697
698impl SignedIntOps<i8> for Avx512Isa {
699 #[inline]
700 fn neg(self, x: I8x64) -> I8x64 {
701 unsafe { _mm512_sub_epi8(_mm512_setzero_si512(), x.0) }.into()
702 }
703}
704
705#[inline]
706fn interleave_low_x8(a: __m512i, b: __m512i) -> __m512i {
707 unsafe {
708 let lo = _mm512_unpacklo_epi8(a, b); let hi = _mm512_unpackhi_epi8(a, b); let idx = _mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23);
712 _mm512_permutex2var_epi32(lo, idx, hi) }
714}
715
716#[inline]
717fn interleave_high_x8(a: __m512i, b: __m512i) -> __m512i {
718 unsafe {
719 let lo = _mm512_unpacklo_epi8(a, b); let hi = _mm512_unpackhi_epi8(a, b); let idx = _mm512_setr_epi32(8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31);
723 _mm512_permutex2var_epi32(lo, idx, hi) }
725}
726
727impl Interleave<i8> for Avx512Isa {
728 #[inline]
729 fn interleave_low(self, a: I8x64, b: I8x64) -> I8x64 {
730 interleave_low_x8(a.0, b.0).into()
731 }
732
733 #[inline]
734 fn interleave_high(self, a: I8x64, b: I8x64) -> I8x64 {
735 interleave_high_x8(a.0, b.0).into()
736 }
737}
738
739unsafe impl NumOps<u8> for Avx512Isa {
740 simd_ops_common!(U8x64, __mmask64);
741 simd_int_ops_common!(U8x64);
742
743 #[inline]
744 fn add(self, x: U8x64, y: U8x64) -> U8x64 {
745 unsafe { _mm512_add_epi8(x.0, y.0) }.into()
746 }
747
748 #[inline]
749 fn sub(self, x: U8x64, y: U8x64) -> U8x64 {
750 unsafe { _mm512_sub_epi8(x.0, y.0) }.into()
751 }
752
753 #[inline]
754 fn mul(self, x: U8x64, y: U8x64) -> U8x64 {
755 let (x_lo, x_hi) = Extend::<u8>::extend(self, x);
756 let (y_lo, y_hi) = Extend::<u8>::extend(self, y);
757
758 let u16_ops = self.u16();
759 let prod_lo = u16_ops.mul(x_lo, y_lo);
760 let prod_hi = u16_ops.mul(x_hi, y_hi);
761
762 self.narrow_truncate(prod_lo, prod_hi)
763 }
764
765 #[inline]
766 fn splat(self, x: u8) -> U8x64 {
767 unsafe { _mm512_set1_epi8(x as i8) }.into()
768 }
769
770 #[inline]
771 fn eq(self, x: U8x64, y: U8x64) -> __mmask64 {
772 unsafe { _mm512_cmpeq_epu8_mask(x.0, y.0) }
773 }
774
775 #[inline]
776 fn ge(self, x: U8x64, y: U8x64) -> __mmask64 {
777 unsafe { _mm512_cmpge_epu8_mask(x.0, y.0) }
778 }
779
780 #[inline]
781 fn gt(self, x: U8x64, y: U8x64) -> __mmask64 {
782 unsafe { _mm512_cmpgt_epu8_mask(x.0, y.0) }
783 }
784
785 #[inline]
786 unsafe fn load_ptr(self, ptr: *const u8) -> U8x64 {
787 unsafe { _mm512_loadu_si512(ptr as *const __m512i) }.into()
788 }
789
790 #[inline]
791 fn select(self, x: U8x64, y: U8x64, mask: <U8x64 as Simd>::Mask) -> U8x64 {
792 unsafe { _mm512_mask_blend_epi8(mask, y.0, x.0) }.into()
793 }
794
795 #[inline]
796 unsafe fn store_ptr(self, x: U8x64, ptr: *mut u8) {
797 unsafe { _mm512_storeu_si512(ptr as *mut __m512i, x.0) }
798 }
799
800 #[inline]
801 unsafe fn load_ptr_mask(self, ptr: *const u8, mask: __mmask64) -> U8x64 {
802 unsafe { _mm512_mask_loadu_epi8(_mm512_set1_epi8(0), mask, ptr as *const i8) }.into()
803 }
804
805 #[inline]
806 unsafe fn store_ptr_mask(self, x: U8x64, ptr: *mut u8, mask: __mmask64) {
807 unsafe { _mm512_mask_storeu_epi8(ptr as *mut i8, mask, x.0) }
808 }
809}
810
811impl Extend<i16> for Avx512Isa {
812 type Output = I32x16;
813
814 #[inline]
815 fn extend(self, x: I16x32) -> (Self::Output, Self::Output) {
816 unsafe {
817 let lo = _mm512_extracti64x4_epi64(x.0, 0);
818 let lo = _mm512_cvtepi16_epi32(lo);
819
820 let hi = _mm512_extracti64x4_epi64(x.0, 1);
821 let hi = _mm512_cvtepi16_epi32(hi);
822 (lo.into(), hi.into())
823 }
824 }
825}
826
827impl Extend<i8> for Avx512Isa {
828 type Output = I16x32;
829
830 #[inline]
831 fn extend(self, x: I8x64) -> (I16x32, I16x32) {
832 unsafe {
833 let lo = _mm512_extracti64x4_epi64(x.0, 0);
834 let lo = _mm512_cvtepi8_epi16(lo);
835
836 let hi = _mm512_extracti64x4_epi64(x.0, 1);
837 let hi = _mm512_cvtepi8_epi16(hi);
838 (lo.into(), hi.into())
839 }
840 }
841}
842
843impl Extend<u8> for Avx512Isa {
844 type Output = U16x32;
845
846 #[inline]
847 fn extend(self, x: U8x64) -> (U16x32, U16x32) {
848 unsafe {
849 let lo = _mm512_extracti64x4_epi64(x.0, 0);
850 let lo = _mm512_cvtepu8_epi16(lo);
851
852 let hi = _mm512_extracti64x4_epi64(x.0, 1);
853 let hi = _mm512_cvtepu8_epi16(hi);
854 (lo.into(), hi.into())
855 }
856 }
857}
858
859impl IntOps<u8> for Avx512Isa {
860 #[inline]
861 fn shift_left<const SHIFT: i32>(self, x: U8x64) -> U8x64 {
862 let (x_lo, x_hi) = Extend::<u8>::extend(self, x);
863
864 let u16_ops = self.u16();
865 let (y_lo, y_hi) = (
866 u16_ops.shift_left::<SHIFT>(x_lo),
867 u16_ops.shift_left::<SHIFT>(x_hi),
868 );
869
870 self.narrow_truncate(y_lo, y_hi)
871 }
872
873 #[inline]
874 fn shift_right<const SHIFT: i32>(self, x: U8x64) -> U8x64 {
875 let (x_lo, x_hi) = Extend::<u8>::extend(self, x);
876
877 let u16_ops = self.u16();
878 let (y_lo, y_hi) = (
879 u16_ops.shift_right::<SHIFT>(x_lo),
880 u16_ops.shift_right::<SHIFT>(x_hi),
881 );
882
883 self.narrow_truncate(y_lo, y_hi)
884 }
885}
886
887impl Interleave<u8> for Avx512Isa {
888 #[inline]
889 fn interleave_low(self, a: U8x64, b: U8x64) -> U8x64 {
890 unsafe {
891 let lo = _mm512_unpacklo_epi8(a.0, b.0); let hi = _mm512_unpackhi_epi8(a.0, b.0); let idx = _mm512_setr_epi32(0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23);
895 _mm512_permutex2var_epi32(lo, idx, hi) }
897 .into()
898 }
899
900 #[inline]
901 fn interleave_high(self, a: U8x64, b: U8x64) -> U8x64 {
902 unsafe {
903 let lo = _mm512_unpacklo_epi8(a.0, b.0); let hi = _mm512_unpackhi_epi8(a.0, b.0); let idx =
907 _mm512_setr_epi32(8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31);
908 _mm512_permutex2var_epi32(lo, idx, hi) }
910 .into()
911 }
912}
913
914impl Narrow<I16x32> for Avx512Isa {
915 type Output = I8x64;
916
917 #[inline]
918 fn narrow_truncate(self, a: I16x32, b: I16x32) -> I8x64 {
919 let y = unsafe {
920 let lo_i8 = _mm512_cvtepi16_epi8(a.0);
921 let hi_i8 = _mm512_cvtepi16_epi8(b.0);
922 _mm512_inserti64x4(_mm512_castsi256_si512(lo_i8), hi_i8, 1)
923 };
924 I8x64(y)
925 }
926}
927
928impl Narrow<U16x32> for Avx512Isa {
929 type Output = U8x64;
930
931 #[inline]
932 fn narrow_truncate(self, a: U16x32, b: U16x32) -> U8x64 {
933 let y = unsafe {
934 let lo_u8 = _mm512_cvtepi16_epi8(a.0);
935 let hi_u8 = _mm512_cvtepi16_epi8(b.0);
936 _mm512_inserti64x4(_mm512_castsi256_si512(lo_u8), hi_u8, 1)
937 };
938 U8x64(y)
939 }
940}
941
942unsafe impl NumOps<u16> for Avx512Isa {
943 simd_ops_common!(U16x32, __mmask32);
944 simd_int_ops_common!(U16x32);
945
946 #[inline]
947 fn add(self, x: U16x32, y: U16x32) -> U16x32 {
948 unsafe { _mm512_add_epi16(x.0, y.0) }.into()
949 }
950
951 #[inline]
952 fn sub(self, x: U16x32, y: U16x32) -> U16x32 {
953 unsafe { _mm512_sub_epi16(x.0, y.0) }.into()
954 }
955
956 #[inline]
957 fn mul(self, x: U16x32, y: U16x32) -> U16x32 {
958 unsafe { _mm512_mullo_epi16(x.0, y.0) }.into()
959 }
960
961 #[inline]
962 fn splat(self, x: u16) -> U16x32 {
963 unsafe { _mm512_set1_epi16(x as i16) }.into()
964 }
965
966 #[inline]
967 fn eq(self, x: U16x32, y: U16x32) -> __mmask32 {
968 unsafe { _mm512_cmp_epu16_mask(x.0, y.0, _MM_CMPINT_EQ) }
969 }
970
971 #[inline]
972 fn ge(self, x: U16x32, y: U16x32) -> __mmask32 {
973 unsafe { _mm512_cmp_epu16_mask(x.0, y.0, _MM_CMPINT_NLT) }
974 }
975
976 #[inline]
977 fn gt(self, x: U16x32, y: U16x32) -> __mmask32 {
978 unsafe { _mm512_cmp_epu16_mask(x.0, y.0, _MM_CMPINT_NLE) }
979 }
980
981 #[inline]
982 unsafe fn load_ptr(self, ptr: *const u16) -> U16x32 {
983 unsafe { _mm512_loadu_si512(ptr as *const __m512i) }.into()
984 }
985
986 #[inline]
987 fn select(self, x: U16x32, y: U16x32, mask: <U16x32 as Simd>::Mask) -> U16x32 {
988 unsafe { _mm512_mask_blend_epi16(mask, y.0, x.0) }.into()
989 }
990
991 #[inline]
992 unsafe fn store_ptr(self, x: U16x32, ptr: *mut u16) {
993 unsafe { _mm512_storeu_si512(ptr as *mut __m512i, x.0) }
994 }
995
996 #[inline]
997 unsafe fn load_ptr_mask(self, ptr: *const u16, mask: __mmask32) -> U16x32 {
998 unsafe { _mm512_mask_loadu_epi16(_mm512_set1_epi16(0), mask, ptr as *const i16) }.into()
999 }
1000
1001 #[inline]
1002 unsafe fn store_ptr_mask(self, x: U16x32, ptr: *mut u16, mask: __mmask32) {
1003 unsafe { _mm512_mask_storeu_epi16(ptr as *mut i16, mask, x.0) }
1004 }
1005}
1006
1007impl IntOps<u16> for Avx512Isa {
1008 #[inline]
1009 fn shift_left<const SHIFT: i32>(self, x: U16x32) -> U16x32 {
1010 let count: I16x32 = self.splat(SHIFT as i16);
1011 unsafe { _mm512_sllv_epi16(x.0, count.0) }.into()
1012 }
1013
1014 #[inline]
1015 fn shift_right<const SHIFT: i32>(self, x: U16x32) -> U16x32 {
1016 let count: I16x32 = self.splat(SHIFT as i16);
1017 unsafe { _mm512_srlv_epi16(x.0, count.0) }.into()
1018 }
1019}
1020
/// Implements `Mask` and `MaskOps` for an integer mask type in which bit `i`
/// holds the flag for lane `i`.
///
/// `$mask` must be an unsigned integer type (the `__mmaskN` aliases are).
macro_rules! impl_mask {
    ($mask:ty) => {
        impl Mask for $mask {
            // One `bool` per bit of the mask.
            type Array = [bool; <$mask>::BITS as usize];

            /// Expands the mask into an array with one `bool` per bit, lowest
            /// bit first.
            #[inline]
            fn to_array(self) -> Self::Array {
                std::array::from_fn(|bit| (self >> bit) & 1 == 1)
            }
        }

        // NOTE(review): `MaskOps` is an `unsafe` trait whose contract is
        // defined elsewhere — confirm these plain integer operations meet it.
        unsafe impl MaskOps<$mask> for Avx512Isa {
            /// Bitwise AND of two masks.
            #[inline]
            fn and(self, x: $mask, y: $mask) -> $mask {
                x & y
            }

            /// True if at least one bit is set.
            #[inline]
            fn any(self, x: $mask) -> bool {
                x != 0
            }

            /// True if every bit is set.
            #[inline]
            fn all(self, x: $mask) -> bool {
                x == <$mask>::MAX
            }
        }
    };
}
1050
// Mask implementations for the three mask widths used by the vector types in
// this file (16, 32 and 64 lanes).
impl_mask!(__mmask16);
impl_mask!(__mmask32);
impl_mask!(__mmask64);