diskann_wide/arch/x86_64/v3/masks.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

use std::arch::x86_64::*;

use super::V3;
use crate::{
    bitmask::{BitMask, FromInt},
    doubled,
    traits::SIMDMask,
};

/////////////////////////////////////
// AVX2 8-bit and 32-bit masks //
/////////////////////////////////////

// mask8x16
#[derive(Debug, Clone, Copy)]
#[allow(non_camel_case_types)]
#[repr(transparent)]
pub struct mask8x16(pub(crate) __m128i);

impl SIMDMask for mask8x16 {
    type Arch = V3;
    type Underlying = __m128i;
    type BitMask = BitMask<16, V3>;
    const ISBITS: bool = false;
    const LANES: usize = 16;

    #[inline(always)]
    fn arch(self) -> V3 {
        // SAFETY: The existence of `self` is proof that we are V3 compatible.
        unsafe { V3::new() }
    }

    #[inline(always)]
    fn to_underlying(self) -> Self::Underlying {
        self.0
    }

    #[inline(always)]
    fn from_underlying(_: V3, value: Self::Underlying) -> Self {
        Self(value)
    }

    #[inline(always)]
    fn keep_first(_: V3, i: usize) -> Self {
        let i = i.min(Self::LANES);
        const CMP: [i8; 16] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];

        // SAFETY: The `V3` architecture instance is proof that we can use the V3 compatible
        // intrinsics invoked here.
        //
        // The unaligned load is valid because the array is the correct size.
        //
        // This local constant is hoisted to a constant in the final binary.
        // Codegen emits a load of this value, so it is relatively cheap to use.
        unsafe {
            let c = _mm_loadu_si128(CMP.as_ptr() as *const __m128i);

            // Broadcast the argument across all SIMD lanes, and compare the broadcast register
            // with the incremental values in `CMP`.
            Self(_mm_cmpgt_epi8(_mm_set1_epi8(i as i8), c))
        }
    }

    fn get_unchecked(&self, i: usize) -> bool {
        // This is not particularly efficient: it performs a full conversion to a
        // bit-mask just to check a single lane.
        // For bulk checking, users should first convert to a bit-mask and then check.
        Into::<Self::BitMask>::into(*self).get_unchecked(i)
    }
}
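
// A scalar model of `keep_first` for reference. This is an illustrative helper
// (not part of the public API): lane `j` of the result is set exactly when
// `j < i`, which is what the broadcast-and-compare above computes in SIMD.
#[cfg(test)]
#[allow(dead_code)]
fn keep_first_model_16(i: usize) -> [bool; 16] {
    std::array::from_fn(|j| j < i.min(16))
}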

// Conversion back-and-forth.
// Credit to https://stackoverflow.com/a/72899629 for this algorithm.
//
// The gist here is that we load `selector` with
// ```ignore
// Lane  |  15   14   13   12   11   10   09   08      07  06   05   04   03   02   01   00
//       |
// Value | 0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01 | 0x80 0x40 0x20 0x10 0x08 0x04 0x02 0x01
// ```
// Then, we put the 2 bytes of the mask into the lower lanes of a `__m128i`.
//
// Using a shuffle, we move the lower byte of the mask into lanes 00 to 07 and the upper
// byte into lanes 08 to 15.
//
// We then use a bit-wise "and" and a comparison to re-create the SIMD mask.
impl From<BitMask<16, V3>> for mask8x16 {
    #[inline(always)]
    fn from(mask: BitMask<16, V3>) -> Self {
        // Extract the underlying integer.
        let mask: u16 = mask.0;

        // Masks used for the bit-twiddling.
        // Select
        // - bit 7 of byte 7
        // - bit 6 of byte 6
        // - bit 5 of byte 5
        // etc.
        const BIT_SELECTOR: i64 = 0x8040201008040201u64 as i64;

        // Select byte 0 and broadcast it across 8 bytes.
        const BROADCAST_BYTE_0: i64 = 0;
        // Select byte 1 and broadcast it across 8 bytes.
        const BROADCAST_BYTE_1: i64 = 0x0101010101010101;

        // SAFETY: The `V3` architecture instance in `mask` is proof that we can use the V3
        // compatible intrinsics invoked here.
        unsafe {
            let selector = _mm_set1_epi64x(BIT_SELECTOR);
            Self(_mm_cmpeq_epi8(
                _mm_and_si128(
                    _mm_shuffle_epi8(
                        _mm_cvtsi32_si128(mask as i32),
                        _mm_set_epi64x(BROADCAST_BYTE_1, BROADCAST_BYTE_0),
                    ),
                    selector,
                ),
                selector,
            ))
        }
    }
}
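
// An illustrative scalar equivalent of the conversion above (a hypothetical
// helper, kept only for exposition): bit `i` of the integer mask becomes byte
// `i` of the byte mask, widened to 0xFF when set and 0x00 when clear.
#[cfg(test)]
#[allow(dead_code)]
fn expand_bitmask_16(mask: u16) -> [u8; 16] {
    std::array::from_fn(|i| if (mask >> i) & 1 == 1 { 0xFF } else { 0x00 })
}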

impl From<mask8x16> for BitMask<16, V3> {
    #[inline(always)]
    fn from(mask: mask8x16) -> Self {
        let m = mask.to_underlying();
        // Use an intrinsic to convert the upper bit of each byte to an integer bit-mask.
        // SAFETY: `_mm_movemask_epi8` requires SSE2, which is implied by `V3`. No memory
        // is touched.
        let bitmask: i32 = unsafe { _mm_movemask_epi8(m) };
        // The intrinsic only sets the lower 16 bits of the returned integer.
        // We can safely truncate to a 16-bit integer.
        BitMask::from_int(mask.arch(), bitmask as u16)
    }
}
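
// Round-trip example (illustrative): `from_int(arch, 0b1010_0101)` expands to a
// byte mask whose bytes 0, 2, 5, and 7 are 0xFF; `_mm_movemask_epi8` then packs
// the byte sign bits back into 0b1010_0101. The tests below exercise this round
// trip over randomly generated masks.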

// mask8x32
#[derive(Debug, Clone, Copy)]
#[allow(non_camel_case_types)]
#[repr(transparent)]
pub struct mask8x32(pub(crate) __m256i);

impl SIMDMask for mask8x32 {
    type Arch = V3;
    type Underlying = __m256i;
    type BitMask = BitMask<32, V3>;
    const ISBITS: bool = false;
    const LANES: usize = 32;

    #[inline(always)]
    fn arch(self) -> V3 {
        // SAFETY: The existence of `self` is proof that we are V3 compatible.
        unsafe { V3::new() }
    }

    #[inline(always)]
    fn to_underlying(self) -> Self::Underlying {
        self.0
    }

    #[inline(always)]
    fn from_underlying(_: V3, value: Self::Underlying) -> Self {
        Self(value)
    }

    #[inline(always)]
    fn keep_first(_: V3, i: usize) -> Self {
        let i = i.min(Self::LANES);
        const CMP: [i8; 32] = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
            24, 25, 26, 27, 28, 29, 30, 31,
        ];

        // SAFETY: The `V3` architecture instance is proof that we can use the V3 compatible
        // intrinsics invoked here.
        //
        // The unaligned load is valid because the array is the correct size.
        //
        // This local constant is hoisted to a constant in the final binary.
        // Codegen emits a load of this value, so it is relatively cheap to use.
        unsafe {
            let c = _mm256_loadu_si256(CMP.as_ptr() as *const __m256i);

            // Broadcast the argument across all SIMD lanes, and compare the broadcast register
            // with the incremental values in `CMP`.
            Self(_mm256_cmpgt_epi8(_mm256_set1_epi8(i as i8), c))
        }
    }

    fn get_unchecked(&self, i: usize) -> bool {
        // This is not particularly efficient: it performs a full conversion to a
        // bit-mask just to check a single lane.
        // For bulk checking, users should first convert to a bit-mask and then check.
        Into::<Self::BitMask>::into(*self).get_unchecked(i)
    }
}

// Conversion back-and-forth.
// Credit to https://stackoverflow.com/a/72899629 for this algorithm.
//
// This follows the same strategy as `From<BitMask<16, V3>> for mask8x16` - just twice as
// wide.
//
// The 32-bit representation of `BitMask` is broadcast to all 8 lanes of a 256-bit wide
// register. If we represent the 32-bit mask in terms of bytes like `b3b2b1b0`, then
// following the broadcast, we get:
// ```text
// |  Lane 0  |  Lane 1  |  Lane 2  |  Lane 3  |  Lane 4  |  Lane 5  |  Lane 6  |  Lane 7  |
// | b3b2b1b0 | b3b2b1b0 | b3b2b1b0 | b3b2b1b0 | b3b2b1b0 | b3b2b1b0 | b3b2b1b0 | b3b2b1b0 |
// ```
// Then, we shuffle to get
// ```text
// |  Lane 0  |  Lane 1  |  Lane 2  |  Lane 3  |  Lane 4  |  Lane 5  |  Lane 6  |  Lane 7  |
// | b0b0b0b0 | b0b0b0b0 | b1b1b1b1 | b1b1b1b1 | b2b2b2b2 | b2b2b2b2 | b3b3b3b3 | b3b3b3b3 |
// ```
// From this position, we apply a bit mask to keep bit 0 of byte position 0 (`b0`), bit 1
// of byte position 1 (still `b0`), and so on. In this way, we can isolate all the bits in
// `mask` into the bytes of a `__m256i`. At that point, `_mm256_cmpeq_epi8` can be used
// to test whether each bit is set and thus create the full mask.
impl From<BitMask<32, V3>> for mask8x32 {
    #[inline(always)]
    fn from(mask: BitMask<32, V3>) -> Self {
        // Extract the underlying integer.
        let mask: u32 = mask.0;

        // Masks used for the bit-twiddling.
        // Select
        // - bit 7 of byte 7
        // - bit 6 of byte 6
        // - bit 5 of byte 5
        // etc.
        const BIT_SELECTOR: i64 = 0x8040201008040201u64 as i64;

        // Select byte 0 and broadcast it across 8 bytes.
        const BROADCAST_BYTE_0: i64 = 0;
        // Select byte 1 and broadcast it across 8 bytes.
        const BROADCAST_BYTE_1: i64 = 0x0101010101010101;
        // Select byte 2 and broadcast it across 8 bytes.
        const BROADCAST_BYTE_2: i64 = 0x0202020202020202;
        // Select byte 3 and broadcast it across 8 bytes.
        const BROADCAST_BYTE_3: i64 = 0x0303030303030303;

        // SAFETY: The `V3` architecture instance in `mask` is proof that we can use the V3
        // compatible intrinsics invoked here.
        unsafe {
            let selector = _mm256_set1_epi64x(BIT_SELECTOR);
            Self(_mm256_cmpeq_epi8(
                _mm256_and_si256(
                    _mm256_shuffle_epi8(
                        _mm256_set1_epi32(mask as i32),
                        _mm256_set_epi64x(
                            BROADCAST_BYTE_3,
                            BROADCAST_BYTE_2,
                            BROADCAST_BYTE_1,
                            BROADCAST_BYTE_0,
                        ),
                    ),
                    selector,
                ),
                selector,
            ))
        }
    }
}
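
// A scalar walk-through of the three SIMD steps above (illustrative only): the
// shuffle broadcasts byte `i / 8` of `mask` into byte position `i`, the AND with
// `BIT_SELECTOR` isolates bit `i % 8` of that byte, and the compare against the
// selector widens the isolated bit into a full 0xFF/0x00 byte.
#[cfg(test)]
#[allow(dead_code)]
fn expand_bitmask_32(mask: u32) -> [u8; 32] {
    std::array::from_fn(|i| {
        let byte = (mask >> (8 * (i / 8))) as u8; // shuffle: broadcast byte i / 8
        let bit = 1u8 << (i % 8); // selector: bit i % 8 of that byte
        if byte & bit == bit { 0xFF } else { 0x00 } // cmpeq against the selector
    })
}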

impl From<mask8x32> for BitMask<32, V3> {
    #[inline(always)]
    fn from(mask: mask8x32) -> Self {
        let m = mask.to_underlying();
        // Use an intrinsic to convert the upper bit of each byte to an integer bit-mask.
        //
        // SAFETY: `_mm256_movemask_epi8` requires AVX2 - which is implied by `V3`.
        let bitmask: i32 = unsafe { _mm256_movemask_epi8(m) };
        BitMask::from_int(mask.arch(), bitmask as u32)
    }
}

// mask32x4
#[derive(Debug, Clone, Copy)]
#[allow(non_camel_case_types)]
#[repr(transparent)]
pub struct mask32x4(pub(crate) __m128i);

impl SIMDMask for mask32x4 {
    type Arch = V3;
    type Underlying = __m128i;
    type BitMask = BitMask<4, V3>;
    const ISBITS: bool = false;
    const LANES: usize = 4;

    #[inline(always)]
    fn arch(self) -> V3 {
        // SAFETY: The existence of `self` is proof that we are V3 compatible.
        unsafe { V3::new() }
    }

    #[inline(always)]
    fn to_underlying(self) -> Self::Underlying {
        self.0
    }

    #[inline(always)]
    fn from_underlying(_: V3, value: Self::Underlying) -> Self {
        Self(value)
    }

    #[inline(always)]
    fn keep_first(_: V3, i: usize) -> Self {
        let i = i.min(Self::LANES);
        const CMP: [i32; 4] = [0, 1, 2, 3];

        // SAFETY: The `V3` architecture instance is proof that we can use the V3 compatible
        // intrinsics invoked here.
        //
        // The unaligned load is valid because the array is the correct size.
        //
        // This local constant is hoisted to a constant in the final binary.
        // Codegen emits a load of this value, so it is relatively cheap to use.
        unsafe {
            let c = _mm_loadu_si128(CMP.as_ptr() as *const __m128i);

            // Broadcast the argument across all SIMD lanes, and compare the broadcast register
            // with the incremental values in `CMP`.
            Self(_mm_cmpgt_epi32(_mm_set1_epi32(i as i32), c))
        }
    }

    fn get_unchecked(&self, i: usize) -> bool {
        // This is not particularly efficient: it performs a full conversion to a
        // bit-mask just to check a single lane.
        // For bulk checking, users should first convert to a bit-mask and then check.
        Into::<Self::BitMask>::into(*self).get_unchecked(i)
    }
}

// Conversion back-and-forth.
impl From<BitMask<4, V3>> for mask32x4 {
    #[inline(always)]
    fn from(mask: BitMask<4, V3>) -> Self {
        // Extract the underlying integer.
        let mask: u8 = mask.0;
        // SAFETY: The `V3` architecture instance in `mask` is proof that we can use the
        // intrinsics invoked here. No memory is touched.
        unsafe {
            // Broadcast the mask to each 32-bit lane, isolate lane `i`'s bit with the
            // power-of-two in `cmp`, and compare against zero to widen it to a full lane.
            let b = _mm_set1_epi32(mask as i32);
            let cmp = _mm_set_epi32(8, 4, 2, 1);
            let x = _mm_and_si128(b, cmp);
            Self(_mm_cmpgt_epi32(x, _mm_setzero_si128()))
        }
    }
}
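
// A scalar model of the conversion above (illustrative only): lane `i` is set
// exactly when bit `i` of the mask is set.
#[cfg(test)]
#[allow(dead_code)]
fn expand_bitmask_4(mask: u8) -> [bool; 4] {
    std::array::from_fn(|i| mask & (1 << i) != 0)
}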

impl From<mask32x4> for BitMask<4, V3> {
    #[inline(always)]
    fn from(mask: mask32x4) -> Self {
        let m = mask.to_underlying();
        // Use an intrinsic to convert the upper bit of each lane to an integer bit-mask.
        // SAFETY: `_mm_movemask_ps` requires SSE, which is implied by `V3`. No memory is
        // touched.
        let bitmask: i32 = unsafe { _mm_movemask_ps(_mm_castsi128_ps(m)) };
        // The intrinsic only sets the lower 4 bits of the returned integer.
        // We can safely truncate to an 8-bit integer.
        BitMask::from_int(mask.arch(), bitmask as u8)
    }
}

// mask32x8
#[derive(Debug, Clone, Copy)]
#[allow(non_camel_case_types)]
#[repr(transparent)]
pub struct mask32x8(pub(crate) __m256i);

impl SIMDMask for mask32x8 {
    type Arch = V3;
    type Underlying = __m256i;
    type BitMask = BitMask<8, V3>;
    const ISBITS: bool = false;
    const LANES: usize = 8;

    #[inline(always)]
    fn arch(self) -> V3 {
        // SAFETY: The existence of `self` is proof that we are V3 compatible.
        unsafe { V3::new() }
    }

    #[inline(always)]
    fn to_underlying(self) -> Self::Underlying {
        self.0
    }

    #[inline(always)]
    fn from_underlying(_: V3, value: Self::Underlying) -> Self {
        Self(value)
    }

    #[inline(always)]
    fn keep_first(_: V3, i: usize) -> Self {
        let i = i.min(Self::LANES);
        // A lookup table of all 9 possible prefix masks: row `i` has its first `i`
        // lanes set to all-ones. For example, row 3 keeps lanes 0 through 2.
        const MASKS: [[u32; 8]; 9] = [
            [0, 0, 0, 0, 0, 0, 0, 0],
            [!0, 0, 0, 0, 0, 0, 0, 0],
            [!0, !0, 0, 0, 0, 0, 0, 0],
            [!0, !0, !0, 0, 0, 0, 0, 0],
            [!0, !0, !0, !0, 0, 0, 0, 0],
            [!0, !0, !0, !0, !0, 0, 0, 0],
            [!0, !0, !0, !0, !0, !0, 0, 0],
            [!0, !0, !0, !0, !0, !0, !0, 0],
            [!0, !0, !0, !0, !0, !0, !0, !0],
        ];

        // SAFETY: `[u32; 8]` and `__m256i` have the same size, and every bit pattern is
        // valid for both types. The index is in bounds because `i` was clamped to
        // `Self::LANES` above.
        Self(unsafe { std::mem::transmute::<[u32; 8], __m256i>(MASKS[i]) })
    }

    fn get_unchecked(&self, i: usize) -> bool {
        // This is not particularly efficient: it performs a full conversion to a
        // bit-mask just to check a single lane.
        // For bulk checking, users should first convert to a bit-mask and then check.
        Into::<Self::BitMask>::into(*self).get_unchecked(i)
    }
}

// Conversion back-and-forth.
impl From<BitMask<8, V3>> for mask32x8 {
    #[inline(always)]
    fn from(mask: BitMask<8, V3>) -> Self {
        // Extract the underlying integer.
        let mask: u8 = mask.0;
        // SAFETY: The `V3` architecture instance in `mask` is proof that we can use the
        // intrinsics invoked here. No memory is touched.
        unsafe {
            // Broadcast the mask to each 32-bit lane, isolate lane `i`'s bit with the
            // power-of-two in `cmp`, and compare against zero to widen it to a full lane.
            let b = _mm256_set1_epi32(mask as i32);
            let cmp = _mm256_set_epi32(128, 64, 32, 16, 8, 4, 2, 1);
            let x = _mm256_and_si256(b, cmp);
            Self(_mm256_cmpgt_epi32(x, _mm256_setzero_si256()))
        }
    }
}

impl From<mask32x8> for BitMask<8, V3> {
    #[inline(always)]
    fn from(mask: mask32x8) -> Self {
        let m = mask.to_underlying();
        // Use an intrinsic to convert the upper bit of each lane to an integer bit-mask.
        // SAFETY: `_mm256_movemask_ps` requires AVX, which is implied by `V3`. No memory
        // is touched.
        let bitmask: i32 = unsafe { _mm256_movemask_ps(_mm256_castsi256_ps(m)) };
        // The intrinsic only sets the lower 8 bits of the returned integer.
        // We can safely truncate to an 8-bit integer.
        BitMask::from_int(mask.arch(), bitmask as u8)
    }
}

///////////////////////
// AVX2 64-bit masks //
///////////////////////

// mask64x2
#[derive(Debug, Clone, Copy)]
#[allow(non_camel_case_types)]
#[repr(transparent)]
pub struct mask64x2(pub(crate) __m128i);

impl SIMDMask for mask64x2 {
    type Arch = V3;
    type Underlying = __m128i;
    type BitMask = BitMask<2, V3>;
    const ISBITS: bool = false;
    const LANES: usize = 2;

    #[inline(always)]
    fn arch(self) -> V3 {
        // SAFETY: The existence of `self` is proof that we are V3 compatible.
        unsafe { V3::new() }
    }

    #[inline(always)]
    fn to_underlying(self) -> Self::Underlying {
        self.0
    }

    #[inline(always)]
    fn from_underlying(_: V3, value: Self::Underlying) -> Self {
        Self(value)
    }

    #[inline(always)]
    fn keep_first(_: V3, i: usize) -> Self {
        let i = i.min(Self::LANES);
        // SAFETY: The `V3` architecture instance is proof that we can use the V3 compatible
        // intrinsics invoked here.
        //
        // The unaligned load is valid because the array is the correct size.
        //
        // This local constant is hoisted to a constant in the final binary.
        // Codegen emits a load of this value, so it is relatively cheap to use.
        unsafe {
            const CMP: [i64; 2] = [0, 1];
            let c = _mm_loadu_si128(CMP.as_ptr() as *const __m128i);

            // Broadcast the argument across all SIMD lanes, and compare the broadcast register
            // with the incremental values in `CMP`.
            Self(_mm_cmpgt_epi64(_mm_set1_epi64x(i as i64), c))
        }
    }

    fn get_unchecked(&self, i: usize) -> bool {
        // This is not particularly efficient: it performs a full conversion to a
        // bit-mask just to check a single lane.
        // For bulk checking, users should first convert to a bit-mask and then check.
        Into::<Self::BitMask>::into(*self).get_unchecked(i)
    }
}

// Conversion back-and-forth.
impl From<BitMask<2, V3>> for mask64x2 {
    #[inline(always)]
    fn from(mask: BitMask<2, V3>) -> Self {
        // Extract the underlying integer.
        let mask: u8 = mask.0;
        // SAFETY: The `V3` architecture instance in `mask` is proof that we can use the
        // intrinsics invoked here. No memory is touched.
        unsafe {
            // Broadcast the mask to each 64-bit lane, isolate lane `i`'s bit with the
            // power-of-two in `cmp`, and compare against zero to widen it to a full lane.
            let b = _mm_set1_epi64x(mask as i64);
            let cmp = _mm_set_epi64x(2, 1);
            let x = _mm_and_si128(b, cmp);
            Self(_mm_cmpgt_epi64(x, _mm_setzero_si128()))
        }
    }
}

impl From<mask64x2> for BitMask<2, V3> {
    #[inline(always)]
    fn from(mask: mask64x2) -> Self {
        let m = mask.to_underlying();
        // Use an intrinsic to convert the upper bit of each lane to an integer bit-mask.
        // SAFETY: `_mm_movemask_pd` requires SSE2, which is implied by `V3`. No memory is
        // touched.
        let bitmask: i32 = unsafe { _mm_movemask_pd(_mm_castsi128_pd(m)) };
        // The intrinsic only sets the lower 2 bits of the returned integer.
        // We can safely truncate to an 8-bit integer.
        BitMask::from_int(mask.arch(), bitmask as u8)
    }
}

// mask64x4
#[derive(Debug, Clone, Copy)]
#[allow(non_camel_case_types)]
#[repr(transparent)]
pub struct mask64x4(pub(crate) __m256i);

impl SIMDMask for mask64x4 {
    type Arch = V3;
    type Underlying = __m256i;
    type BitMask = BitMask<4, V3>;
    const ISBITS: bool = false;
    const LANES: usize = 4;

    #[inline(always)]
    fn arch(self) -> V3 {
        // SAFETY: The existence of `self` is proof that we are V3 compatible.
        unsafe { V3::new() }
    }

    #[inline(always)]
    fn to_underlying(self) -> Self::Underlying {
        self.0
    }

    #[inline(always)]
    fn from_underlying(_: V3, value: Self::Underlying) -> Self {
        Self(value)
    }

    #[inline(always)]
    fn keep_first(_: V3, i: usize) -> Self {
        let i = i.min(Self::LANES);
        // SAFETY: The `V3` architecture instance is proof that we can use the V3 compatible
        // intrinsics invoked here.
        //
        // The unaligned load is valid because the array is the correct size.
        //
        // This local constant is hoisted to a constant in the final binary.
        // Codegen emits a load of this value, so it is relatively cheap to use.
        unsafe {
            const CMP: [i64; 4] = [0, 1, 2, 3];
            let c = _mm256_loadu_si256(CMP.as_ptr() as *const __m256i);

            // Broadcast the argument across all SIMD lanes, and compare the broadcast register
            // with the incremental values in `CMP`.
            Self(_mm256_cmpgt_epi64(_mm256_set1_epi64x(i as i64), c))
        }
    }

    fn get_unchecked(&self, i: usize) -> bool {
        // This is not particularly efficient: it performs a full conversion to a
        // bit-mask just to check a single lane.
        // For bulk checking, users should first convert to a bit-mask and then check.
        Into::<Self::BitMask>::into(*self).get_unchecked(i)
    }
}

// Conversion back-and-forth.
impl From<BitMask<4, V3>> for mask64x4 {
    #[inline(always)]
    fn from(mask: BitMask<4, V3>) -> Self {
        // Extract the underlying integer.
        let mask: u8 = mask.0;
        // SAFETY: The `V3` architecture instance in `mask` is proof that we can use the
        // intrinsics invoked here. No memory is touched.
        unsafe {
            // Broadcast the mask to each 64-bit lane, isolate lane `i`'s bit with the
            // power-of-two in `cmp`, and compare against zero to widen it to a full lane.
            let b = _mm256_set1_epi64x(mask as i64);
            let cmp = _mm256_set_epi64x(8, 4, 2, 1);
            let x = _mm256_and_si256(b, cmp);
            Self(_mm256_cmpgt_epi64(x, _mm256_setzero_si256()))
        }
    }
}

impl From<mask64x4> for BitMask<4, V3> {
    #[inline(always)]
    fn from(mask: mask64x4) -> Self {
        let m = mask.to_underlying();
        // Use an intrinsic to convert the upper bit of each lane to an integer bit-mask.
        // SAFETY: `_mm256_movemask_pd` requires AVX, which is implied by `V3`. No memory
        // is touched.
        let bitmask: i32 = unsafe { _mm256_movemask_pd(_mm256_castsi256_pd(m)) };
        // The intrinsic only sets the lower 4 bits of the returned integer.
        // We can safely truncate to an 8-bit integer.
        BitMask::from_int(mask.arch(), bitmask as u8)
    }
}

//////////////////
// Double Masks //
//////////////////

// These mask definitions are shared across the double-wide implementations.

// Native Masks
doubled::double_mask!(64, mask8x32);
doubled::double_mask!(16, mask32x8);

// Bit Mask
doubled::double_mask!(32, BitMask<16, V3>);

#[cfg(test)]
mod test_masks {
    use rand::{Rng, SeedableRng};

    use super::*;
    use crate::{
        Architecture, BitMask, Const, FromInt, SupportedLaneCount, doubled::Doubled, test_utils,
        traits::SIMDMask,
    };

    trait TypeRange:
        Copy + rand::distr::uniform::SampleUniform + std::cmp::PartialOrd + std::fmt::Display
    {
        fn make_range_() -> std::ops::RangeInclusive<Self>;
    }

    impl TypeRange for u8 {
        fn make_range_() -> std::ops::RangeInclusive<Self> {
            Self::MIN..=Self::MAX
        }
    }

    impl TypeRange for u16 {
        fn make_range_() -> std::ops::RangeInclusive<Self> {
            Self::MIN..=Self::MAX
        }
    }

    /// A trait to extract the top bit of an integer.
    trait TopBit {
        fn is_top_bit_set(&self) -> bool;
    }

    impl TopBit for u8 {
        fn is_top_bit_set(&self) -> bool {
            (self & 0x80) != 0
        }
    }

    impl TopBit for u32 {
        fn is_top_bit_set(&self) -> bool {
            (self & 0x8000_0000) != 0
        }
    }

    impl TopBit for u64 {
        fn is_top_bit_set(&self) -> bool {
            (self & 0x8000_0000_0000_0000) != 0
        }
    }

    /// Trait to compare AVX2 masks with a corresponding bit-mask.
    trait CheckWithBitmask {
        type BitMask: SIMDMask;
        /// Panics on a mismatch.
        fn check(self, bitmask: Self::BitMask);
    }

    impl CheckWithBitmask for mask8x16 {
        type BitMask = BitMask<16>;
        fn check(self, bitmask: Self::BitMask) {
            // Transmute the underlying register to the correct array.
            //
            // SAFETY: The two types are the same length, do not hold any resources, and
            // are valid for all possible bit patterns.
            let array = unsafe { std::mem::transmute::<__m128i, [u8; 16]>(self.to_underlying()) };
            for (i, v) in array.iter().enumerate() {
                assert_eq!(v.is_top_bit_set(), bitmask.get(i).unwrap());
            }
        }
    }

    impl CheckWithBitmask for mask8x32 {
        type BitMask = BitMask<32>;
        fn check(self, bitmask: Self::BitMask) {
            // Transmute the underlying register to the correct array.
            //
            // SAFETY: The two types are the same length, do not hold any resources, and
            // are valid for all possible bit patterns.
            let array = unsafe { std::mem::transmute::<__m256i, [u8; 32]>(self.to_underlying()) };
            for (i, v) in array.iter().enumerate() {
                assert_eq!(v.is_top_bit_set(), bitmask.get(i).unwrap());
            }
        }
    }

    impl CheckWithBitmask for mask32x4 {
        type BitMask = BitMask<4>;
        fn check(self, bitmask: Self::BitMask) {
            // Transmute the underlying register to the correct array.
            //
            // SAFETY: The two types are the same length, do not hold any resources, and
            // are valid for all possible bit patterns.
            let array = unsafe { std::mem::transmute::<__m128i, [u32; 4]>(self.to_underlying()) };
            for (i, v) in array.iter().enumerate() {
                assert_eq!(v.is_top_bit_set(), bitmask.get(i).unwrap());
            }
        }
    }

    impl CheckWithBitmask for mask32x8 {
        type BitMask = BitMask<8>;
        fn check(self, bitmask: Self::BitMask) {
            // Transmute the underlying register to the correct array.
            //
            // SAFETY: The two types are the same length, do not hold any resources, and
            // are valid for all possible bit patterns.
            let array = unsafe { std::mem::transmute::<__m256i, [u32; 8]>(self.to_underlying()) };
            for (i, v) in array.iter().enumerate() {
                assert_eq!(v.is_top_bit_set(), bitmask.get(i).unwrap());
            }
        }
    }

    impl CheckWithBitmask for mask64x2 {
        type BitMask = BitMask<2>;
        fn check(self, bitmask: Self::BitMask) {
            // Transmute the underlying register to the correct array.
            //
            // SAFETY: The two types are the same length, do not hold any resources, and
            // are valid for all possible bit patterns.
            let array = unsafe { std::mem::transmute::<__m128i, [u64; 2]>(self.to_underlying()) };
            for (i, v) in array.iter().enumerate() {
                assert_eq!(v.is_top_bit_set(), bitmask.get(i).unwrap());
            }
        }
    }

    impl CheckWithBitmask for mask64x4 {
        type BitMask = BitMask<4>;
        fn check(self, bitmask: Self::BitMask) {
            // Transmute the underlying register to the correct array.
            //
            // SAFETY: The two types are the same length, do not hold any resources, and
            // are valid for all possible bit patterns.
            let array = unsafe { std::mem::transmute::<__m256i, [u64; 4]>(self.to_underlying()) };
            for (i, v) in array.iter().enumerate() {
                assert_eq!(v.is_top_bit_set(), bitmask.get(i).unwrap());
            }
        }
    }

    fn check_avx2_mask<T, const N: usize, A>(mask: T, bitmask: BitMask<N, A>)
    where
        A: Architecture,
        Const<N>: SupportedLaneCount,
        BitMask<N, A>: SIMDMask,
        T: CheckWithBitmask<BitMask = BitMask<N>>,
    {
        mask.check(bitmask.as_current())
    }

    /// Test the conversion from bitmask to full-mask.
    ///
    /// Randomly generates bitmasks, constructs a full-mask, and ensures the `get` API
    /// yields the same results.
    ///
    /// Also checks that converting the full-mask back to a bitmask is lossless.
    fn test_mask_conversion_impl<T, const N: usize, A>(arch: A, num_trials: usize, seed: u64)
    where
        A: Architecture,
        Const<N>: SupportedLaneCount,
        BitMask<N, A>:
            SIMDMask<Arch = A> + From<T> + FromInt<<BitMask<N, A> as SIMDMask>::Underlying, A>,
        T: SIMDMask<Arch = A, BitMask = BitMask<N, A>> + From<BitMask<N, A>>,
        <BitMask<N, A> as SIMDMask>::Underlying: TypeRange,
    {
        const MAXLEN: usize = 64;
        assert_eq!(T::LANES, N);
        assert!(MAXLEN >= T::LANES);

        let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
        for _ in 0..num_trials {
            let u = rng.random_range(
                <<BitMask<N, A> as SIMDMask>::Underlying as TypeRange>::make_range_(),
            );

            let bit_mask = BitMask::<N, A>::from_int(arch, u);
            let full_mask: T = bit_mask.into();
            for i in 0..=MAXLEN {
                assert_eq!(bit_mask.get(i), full_mask.get(i));
                assert_eq!(bit_mask.get_unchecked(i), full_mask.get_unchecked(i));
            }

            let from_full: BitMask<N, A> = full_mask.into();
            assert_eq!(from_full, bit_mask);
        }
    }

    #[test]
    fn test_mask_conversion() {
        if let Some(arch) = V3::new_checked_uncached() {
            test_mask_conversion_impl::<mask8x16, 16, _>(arch, 5000, 0x12345);

            test_mask_conversion_impl::<mask32x4, 4, _>(arch, 200, 0xc0ffee);
            test_mask_conversion_impl::<mask32x8, 8, _>(arch, 1000, 0x7a08f5);

            test_mask_conversion_impl::<mask64x4, 4, _>(arch, 32, 0x7a08f5);
            test_mask_conversion_impl::<mask64x2, 2, _>(arch, 32, 0xc59783c5d8c4b59b);
        }
    }

    #[test]
    #[should_panic]
    fn test_check_avx2_mask_panics_mask8x16() {
        if let Some(arch) = V3::new_checked_uncached() {
            let m = mask8x16::from_fn(arch, |i| i < 7);
            let bm = BitMask::<16, V3>::from_fn(arch, |i| i < 10);
            check_avx2_mask(m, bm);
        } else {
            panic!("skipping test due to architecture incompatibility");
        }
    }

    #[test]
    #[should_panic]
    fn test_check_avx2_mask_panics_mask8x32() {
        if let Some(arch) = V3::new_checked_uncached() {
            let m = mask8x32::from_fn(arch, |i| i < 7);
            let bm = BitMask::<32, V3>::from_fn(arch, |i| i < 10);
            check_avx2_mask(m, bm);
        } else {
            panic!("skipping test due to architecture incompatibility");
        }
    }

    #[test]
    #[should_panic]
    fn test_check_avx2_mask_panics_mask32x4() {
        if let Some(arch) = V3::new_checked_uncached() {
            let m = mask32x4::from_fn(arch, |i| i < 3);
            let bm = BitMask::<4, V3>::from_fn(arch, |i| i <= 3);
            check_avx2_mask(m, bm);
        } else {
            panic!("skipping test due to architecture incompatibility");
        }
    }

    #[test]
    #[should_panic]
    fn test_check_avx2_mask_panics_mask32x8() {
        if let Some(arch) = V3::new_checked_uncached() {
            let m = mask32x8::from_fn(arch, |i| i < 7);
            let bm = BitMask::<8, V3>::from_fn(arch, |i| i <= 7);
            check_avx2_mask(m, bm);
        } else {
            panic!("skipping test due to architecture incompatibility");
        }
    }

    #[test]
    #[should_panic]
    fn test_check_avx2_mask_panics_mask64x2() {
        if let Some(arch) = V3::new_checked_uncached() {
            let m = mask64x2::from_fn(arch, |i| i < 1);
            let bm = BitMask::<2, V3>::from_fn(arch, |i| i <= 1);
            check_avx2_mask(m, bm);
        } else {
            panic!("skipping test due to architecture incompatibility");
        }
    }

    #[test]
    #[should_panic]
    fn test_check_avx2_mask_panics_mask64x4() {
        if let Some(arch) = V3::new_checked_uncached() {
            let m = mask64x4::from_fn(arch, |i| i < 3);
            let bm = BitMask::<4, V3>::from_fn(arch, |i| i <= 3);
            check_avx2_mask(m, bm);
        } else {
            panic!("skipping test due to architecture incompatibility");
        }
    }

    // Helper macro to run the AVX2 masks through the SIMDMask test routines.
    macro_rules! test_simdmask {
        ($mask:ident $(< $($ps:tt),+ >)?, $N:literal, $checker:expr) => {
            paste::paste! {
                #[test]
                fn [<test_simd_mask_ $mask:lower $(_$($ps:lower )x+)? x $N>]() {
                    type T = $mask $(< $($ps),+>)?;
                    if let Some(arch) = V3::new_checked_uncached() {
                        test_utils::mask::test_keep_first::<T, $N, _, _>(arch, $checker);
                        test_utils::mask::test_from_fn::<T, $N, _, _>(arch, $checker);
                        test_utils::mask::test_reductions::<T, $N, _, _>(arch, $checker);
                        test_utils::mask::test_first::<T, $N, _, _>(arch, $checker);
                    }
                }
            }
        };
    }

    test_simdmask!(mask8x16, 16, check_avx2_mask);
    test_simdmask!(mask8x32, 32, check_avx2_mask);

    test_simdmask!(mask32x4, 4, check_avx2_mask);
    test_simdmask!(mask32x8, 8, check_avx2_mask);

    test_simdmask!(mask64x2, 2, check_avx2_mask);
    test_simdmask!(mask64x4, 4, check_avx2_mask);

    fn nop<T, const N: usize, A>(_: T, _: BitMask<N, A>)
    where
        A: crate::arch::Sealed,
        Const<N>: SupportedLaneCount,
    {
    }

    // Double
    test_simdmask!(Doubled<mask8x32>, 64, nop);
    test_simdmask!(Doubled<mask32x8>, 16, nop);

    // Type alias to work around limitations in `test_simdmask`.
    type BitMask16V3 = BitMask<16, V3>;
    test_simdmask!(Doubled<BitMask16V3>, 32, nop);
}