
diskann_wide/traits.rs

/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

use super::{
    arch,
    bitmask::{BitMask, FromInt},
    constant::{Const, SupportedLaneCount},
};

/// Rust currently lacks the ability to use a trait's associated constants as constraints
/// or constant parameters to other type definitions within the same trait.
///
/// The use of the `Const` type moves the const generic into the type domain, allowing us
/// to properly constrain and define other associated items.
///
/// In particular, we currently can't do nice things like:
/// ```ignore
/// trait MyTrait {
///     const MY_CONSTANT: usize;
///     type ArrayType = [f32; Self::MY_CONSTANT];
///                            ----------------- Currently unsupported.
/// }
/// ```
///
/// See:
/// - <https://users.rust-lang.org/t/limitation-of-associated-const-in-traits/73491/2>
/// - <https://github.com/rust-lang/rust/issues/76560>
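///
/// With the workaround, the lane count lives in the type `Const<N>` and an implementation
/// supplies the array type. A minimal sketch of the resulting usage, based on the impl
/// below:
/// ```ignore
/// // Recovers `[f32; 8]` through the trait rather than through an associated constant.
/// type EightFloats = <Const<8> as ArrayType<f32>>::Type;
/// ```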
pub trait ArrayType<T>: SupportedLaneCount {
    type Type;
}

/// Map scalar types and lane counts to arrays.
impl<T, const N: usize> ArrayType<T> for Const<N>
where
    Const<N>: SupportedLaneCount,
{
    type Type = [T; N];
}

/// Stable Rust does not allow expressions involving compile-time computation with
/// const generic parameters: <https://github.com/rust-lang/rust/issues/76560>
///
/// This makes it difficult to go from a const parameter defining the number of SIMD lanes
/// to an appropriately sized mask.
///
/// This helper trait provides a level of indirection to map SIMD representations to
/// the associated bitmask.
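///
/// A minimal sketch of the mapping, based on the impl below:
/// ```ignore
/// // Recovers `BitMask<8, A>` for an architecture `A` from the lane count alone.
/// type EightLaneMask<A> = <Const<8> as BitMaskType<A>>::Type;
/// ```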
pub trait BitMaskType<A: arch::Sealed>: SupportedLaneCount {
    type Type; // Should always be a `BitMask`.
}

impl<A, const N: usize> BitMaskType<A> for Const<N>
where
    Const<N>: SupportedLaneCount,
    A: arch::Sealed,
{
    type Type = BitMask<N, A>;
}

/// Convert `Self` to the SIMD type `T`. This is mainly useful when implementing fallback
/// operations through [`crate::Emulated`] to restore the original SIMD type.
pub trait AsSIMD<T>: Copy
where
    T: SIMDVector,
{
    fn as_simd(self, arch: T::Arch) -> T;
}

/// A logical mask for SIMD operations.
///
/// The representation of this type varies between architectures and micro-architectures.
/// For example:
///
/// * On AVX2 systems, a SIMD mask for the type/length pair `(T, N)` consists of a SIMD
///   register of unsigned integers with the same size as `T` and length `N`.
///
///   The semantics of such registers are to allow operations only in lanes where the
///   top-most bit is set to 1.
///
/// * On AVX-512 systems, the story is much simpler as the masks used in that instruction
///   set are simply the corresponding bit mask.
///
///   So a mask for 8-wide operations is simply an 8-bit unsigned integer.
///
/// * Emulated systems should use a bit-mask for the most compact representation.
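///
/// A minimal usage sketch, assuming `arch` is a handle for the current architecture and
/// `M` is an 8-lane implementation of this trait:
/// ```ignore
/// let mask = M::keep_first(arch, 3); // Lanes 0, 1, and 2 are set.
/// assert_eq!(mask.get(0), Some(true));
/// assert_eq!(mask.get(7), Some(false));
/// assert_eq!(mask.count(), 3);
/// assert!(mask.any() && !mask.all());
/// ```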
pub trait SIMDMask: Copy + std::fmt::Debug {
    /// The architecture type this struct belongs to.
    type Arch: arch::Sealed;

    /// The type of the underlying intrinsic.
    type Underlying: Copy + std::fmt::Debug;

    /// The bitmask associated with the logical mask.
    type BitMask: SIMDMask<Arch = Self::Arch> + Into<Self> + From<Self>;

    /// The number of lanes in the bitmask.
    const LANES: usize;

    /// Whether or not this mask implementation is a bit mask.
    const ISBITS: bool;

    /// Return the architecture object associated with this vector.
    fn arch(self) -> Self::Arch;

    /// Retrieve the underlying type.
    /// This will always be an unsigned integer of the minimum width required to contain
    /// `LANES` bits.
    fn to_underlying(self) -> Self::Underlying;

    /// Construct the mask from the underlying type.
    fn from_underlying(arch: Self::Arch, value: Self::Underlying) -> Self;

    /// Return `true` if lane `i` is set and `false` otherwise.
    ///
    /// This method is unchecked, but safe in the sense that if `i >= LANES`, `false` will
    /// always be returned. No out-of-bounds access will be made, but no error indication
    /// will be provided.
    fn get_unchecked(&self, i: usize) -> bool;

    /// Efficiently construct a new mask with the first `i` bits set and the remainder
    /// set to zero.
    ///
    /// If `i >= LANES` then all bits will be set.
    fn keep_first(arch: Self::Arch, i: usize) -> Self;

    /// Return the first set index in the mask or `None` if no entries are set.
    fn first(&self) -> Option<usize> {
        self.bitmask().first()
    }

    //////////////////////////////
    // Provided Implementations //
    //////////////////////////////

    /// Return the associated BitMask for this Mask.
    fn bitmask(self) -> Self::BitMask {
        <Self::BitMask as From<Self>>::from(self)
    }

    /// Return `true` if lane `i` is set and `false` otherwise. Returns an empty `Option`
    /// if the index `i` is out-of-bounds.
    fn get(&self, i: usize) -> Option<bool> {
        if i >= Self::LANES {
            None
        } else {
            Some(self.get_unchecked(i))
        }
    }

    /// Construct a mask based on the result of invoking `f` once for each element in the
    /// range `0..Self::LANES` in order.
    ///
    /// In the returned mask `m`, `m.get(0)` corresponds to the value of `f(0)`. Similarly,
    /// `m.get(1)` corresponds to `f(1)`, etc.
    #[inline(always)]
    fn from_fn<F>(arch: Self::Arch, f: F) -> Self
    where
        F: FnMut(usize) -> bool,
    {
        // Recurse to BitMask.
        Self::BitMask::from_fn(arch, f).into()
    }

    /// Return `true` if any lane in the mask is set. Otherwise, return `false`.
    #[inline(always)]
    fn any(self) -> bool {
        // Recurse to BitMask.
        <Self::BitMask as From<Self>>::from(self).any()
    }

    /// Return `true` if all lanes in the mask are set. Otherwise, return `false`.
    #[inline(always)]
    fn all(self) -> bool {
        // Recurse to BitMask.
        <Self::BitMask as From<Self>>::from(self).all()
    }

    /// Return `true` if no lanes in the mask are set. Otherwise, return `false`.
    #[inline(always)]
    fn none(self) -> bool {
        !self.any()
    }

    /// Return the number of lanes that evaluate to `true`.
    #[inline(always)]
    fn count(self) -> usize {
        // Recurse to BitMask.
        <Self::BitMask as From<Self>>::from(self).count()
    }
}

/// A trait representing minimal behavior for a SIMD-like vector.
///
/// A SIMDVector can be thought of as a homogeneous array `[T; N]` (with potentially
/// stricter alignment requirements) that generally behaves for arithmetic purposes like a
/// scalar, in the sense that
/// ```ignore
/// fn add(a: V, b: V) -> V
/// where V: SIMDVector {
///     a + b
/// }
/// ```
/// has the semantics of broadcasting the `+` operation across all lanes of the
/// vector.
pub trait SIMDVector: Copy + std::fmt::Debug {
    /// The architecture this vector belongs to.
    type Arch: arch::Sealed;

    /// The type of each element in the vector.
    type Scalar: Copy + std::fmt::Debug;

    /// The underlying representation.
    type Underlying: Copy;

    /// The number of lanes in the vector.
    const LANES: usize;

    /// The value of `LANES` but in the type domain so we can use it to constrain other
    /// aspects of this trait.
    ///
    /// Should be the type `Const<Self::LANES>`.
    type ConstLanes: ArrayType<Self::Scalar> + BitMaskType<Self::Arch>;

    /// The expanded logical mask representation.
    /// This may-or-may-not actually be a bitmask, but should be easily convertible to and
    /// from a bitmask.
    type Mask: SIMDMask<Arch = Self::Arch>
        + From<<Self::ConstLanes as BitMaskType<Self::Arch>>::Type>
        + Into<<Self::ConstLanes as BitMaskType<Self::Arch>>::Type>;

    /// Whether or not this is an emulated vector.
    ///
    /// Emulated vectors are backed by Rust arrays and use scalar loops to implement
    /// arithmetic operations.
    const EMULATED: bool;

    /// Return the architecture object associated with this vector.
    ///
    /// # NOTE
    ///
    /// This is safe because construction of `self` serves as the witness that we are on
    /// a compatible architecture.
    fn arch(self) -> Self::Arch;

    /// Return the default value for the type. This is always the numeric 0 for the
    /// associated scalar type.
    fn default(arch: Self::Arch) -> Self;

    /// Return the underlying type.
    fn to_underlying(self) -> Self::Underlying;

    /// Construct from the underlying type.
    fn from_underlying(arch: Self::Arch, repr: Self::Underlying) -> Self;

    /// Retrieve the contents as an array.
    fn to_array(self) -> <Self::ConstLanes as ArrayType<Self::Scalar>>::Type;

    /// Construct from the associated array.
    ///
    /// The argument `arch` provides a "proof of compatibility" as `A` can only be safely
    /// instantiated when all the requirements for the architecture are met.
    fn from_array(arch: Self::Arch, x: <Self::ConstLanes as ArrayType<Self::Scalar>>::Type)
    -> Self;

    /// Broadcast the provided scalar across all lanes.
    ///
    /// The argument `arch` provides a "proof of compatibility" as `A` can only be safely
    /// instantiated when all the requirements for the architecture are met.
    fn splat(arch: Self::Arch, value: Self::Scalar) -> Self;

    /// Return the number of lanes in this vector.
    fn num_lanes() -> usize {
        Self::LANES
    }

    /// Load `<Self as SIMDVector>::LANES` elements starting at the provided
    /// pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// # Safety
    ///
    /// A contiguous read of `<Self as SIMDVector>::LANES` elements must touch valid memory.
    unsafe fn load_simd(arch: Self::Arch, ptr: *const <Self as SIMDVector>::Scalar) -> Self;

    /// Load `<Self as SIMDVector>::LANES` elements starting at the provided
    /// pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// Entries in the mask that evaluate to `false` will not be accessed.
    /// This makes it safe to use this function with lanes masked out that would otherwise
    /// cross a page boundary or otherwise cause an out-of-bounds read.
    ///
    /// # Safety
    ///
    /// Offsets from `ptr` where the mask evaluates to `true` must be dereferenceable to
    /// the underlying scalar type.
    unsafe fn load_simd_masked_logical(
        arch: Self::Arch,
        ptr: *const <Self as SIMDVector>::Scalar,
        mask: <Self as SIMDVector>::Mask,
    ) -> Self;

    /// The same as `load_simd_masked_logical` but taking a BitMask instead.
    ///
    /// No load attempt will be made to lanes that are masked out.
    ///
    /// # Safety
    ///
    /// Offsets from `ptr` where the mask evaluates to `true` must be dereferenceable to
    /// the underlying scalar type. For implementations using the provided default, the
    /// conversion from the bitmask to the actual mask must be correct.
    #[inline(always)]
    unsafe fn load_simd_masked(
        arch: Self::Arch,
        ptr: *const <Self as SIMDVector>::Scalar,
        mask: <<Self as SIMDVector>::ConstLanes as BitMaskType<Self::Arch>>::Type,
    ) -> Self {
        // SAFETY: Bitmasks must be convertible to their corresponding logical mask.
        // When the logical mask **is** a bitmask, this is a no-op.
        unsafe { Self::load_simd_masked_logical(arch, ptr, mask.into()) }
    }

    /// The same as `load_simd_masked_logical`, but potentially specialized for situations
    /// where it is known that some number of first elements will be accessed.
    ///
    /// If `first` is greater than or equal to the number of lanes, then all lanes will be
    /// loaded.
    ///
    /// # Safety
    ///
    /// A contiguous read of `first.min(<Self as SIMDVector>::LANES)` elements must be valid.
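    ///
    /// # Example
    ///
    /// A sketch of tail handling over a slice (`data: &[f32]`), assuming a vector type
    /// `V: SIMDVector<Scalar = f32>`:
    /// ```ignore
    /// let mut i = 0;
    /// while i < data.len() {
    ///     // The final, partial iteration reads only the `data.len() - i` remaining lanes.
    ///     let v = unsafe { V::load_simd_first(arch, data.as_ptr().add(i), data.len() - i) };
    ///     // ... use `v` ...
    ///     i += V::LANES;
    /// }
    /// ```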
    #[inline(always)]
    unsafe fn load_simd_first(
        arch: Self::Arch,
        ptr: *const <Self as SIMDVector>::Scalar,
        first: usize,
    ) -> Self {
        // SAFETY: The implementation of `SIMDMask` must be correct.
        unsafe {
            Self::load_simd_masked_logical(
                arch,
                ptr,
                <Self as SIMDVector>::Mask::keep_first(arch, first),
            )
        }
    }

    /// Store `<Self as SIMDVector>::LANES` elements contiguously starting at the
    /// provided pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// A contiguous store of `<Self as SIMDVector>::LANES` elements must touch valid memory.
    unsafe fn store_simd(self, ptr: *mut <Self as SIMDVector>::Scalar);

    /// Store `<Self as SIMDVector>::LANES` elements starting at the provided
    /// pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// Entries in the mask that evaluate to `false` will not be accessed.
    /// This makes it safe to use this function with lanes masked out that would otherwise
    /// cross a page boundary or otherwise cause an out-of-bounds write.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// Offsets from `ptr` where the mask evaluates to `true` must be mutably
    /// dereferenceable to the underlying scalar type.
    unsafe fn store_simd_masked_logical(
        self,
        ptr: *mut <Self as SIMDVector>::Scalar,
        mask: <Self as SIMDVector>::Mask,
    );

    /// The same as `store_simd_masked_logical` but taking a BitMask instead.
    ///
    /// No store attempt will be made to lanes that are masked out.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// Offsets from `ptr` where the mask evaluates to `true` must be mutably
    /// dereferenceable to the underlying scalar type.
    ///
    /// For implementations using the provided default, the conversion from the bitmask to
    /// the actual mask must be correct.
    #[inline(always)]
    unsafe fn store_simd_masked(
        self,
        ptr: *mut <Self as SIMDVector>::Scalar,
        mask: <<Self as SIMDVector>::ConstLanes as BitMaskType<Self::Arch>>::Type,
    ) {
        // SAFETY: Bitmasks must be convertible to their corresponding logical mask.
        // When the logical mask **is** a bitmask, this is a no-op.
        unsafe { self.store_simd_masked_logical(ptr, mask.into()) }
    }

    /// The same as `store_simd_masked_logical`, but potentially specialized for situations
    /// where it is known that some number of first elements will be accessed.
    ///
    /// If `first` is greater than or equal to the number of lanes, then all lanes will be
    /// written.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// A contiguous write of `first.min(<Self as SIMDVector>::LANES)` elements must be valid.
    #[inline(always)]
    unsafe fn store_simd_first(self, ptr: *mut <Self as SIMDVector>::Scalar, first: usize) {
        // SAFETY: The implementation of `SIMDMask` must be correct.
        unsafe {
            self.store_simd_masked_logical(
                ptr,
                <Self as SIMDVector>::Mask::keep_first(self.arch(), first),
            )
        }
    }

    /// Perform a numeric cast on each element, returning a new SIMD vector.
    ///
    /// See also: [`SIMDCast`].
    #[inline(always)]
    fn cast<T>(self) -> <Self as SIMDCast<T>>::Cast
    where
        Self: SIMDCast<T>,
    {
        self.simd_cast()
    }
}

/// Efficiently perform the operation
/// ```ignore
/// self * rhs + accumulator
/// ```
/// with the following semantics, dependent on the associated scalar type.
///
/// * floating point: Perform a fused multiply-add, implementing the operation with only a
///   single rounding instance.
///
/// * integer: Perform the multiplication followed by the accumulation. Both binary
///   operations will be performed using wrap-around arithmetic.
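///
/// A per-lane sketch of the intended semantics, assuming a floating-point scalar type:
/// ```ignore
/// // out[i] == self[i].mul_add(rhs[i], accumulator[i]) -- a single rounding step per lane.
/// let out = v.mul_add_simd(rhs, accumulator);
/// ```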
pub trait SIMDMulAdd {
    fn mul_add_simd(self, rhs: Self, accumulator: Self) -> Self;
}

/// Efficiently retrieve the pairwise minimum or maximum for the two arguments.
///
/// Each function comes in two flavors:
///
/// * Standard (suffixed): Compute the minimum or maximum in a way that is equivalent to
///   Rust's built-in minimum or maximum functions.
///
///   When the scalar type is integral, the behavior is unambiguous.
///
///   When the scalar type is floating point and one value of a pair is NaN, the other
///   value is returned. When the result is zero, either a positive or a negative zero can
///   be returned.
///
/// * Fast (unsuffixed): Compute the minimum or maximum using the fastest possible method
///   on the given architecture with non-standard NaN handling.
///
///   When the scalar type is integral, the behavior is the same as the standard
///   implementations.
///
///   When the scalar type is floating point, the implementation is allowed to differ
///   with respect to NaN handling. That is, when one of the arguments is NaN, the
///   implementation is allowed to return **either** the other argument (like the standard
///   implementation) or NaN. Like the standard implementation, if the result is zero, then
///   a zero of either sign can be returned.
///
///   This method should be preferred when precise NaN handling is not needed as it can be
///   more efficient.
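///
/// A minimal sketch of the NaN contract, assuming `f32` vectors where every lane of `a` is
/// NaN and every lane of `b` is `1.0`:
/// ```ignore
/// let m = a.min_simd_standard(b); // Every lane is 1.0.
/// let f = a.min_simd(b);          // Every lane is either 1.0 or NaN.
/// ```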
pub trait SIMDMinMax: Sized {
    /// Return the pairwise minimum of `self` and `rhs`, subject to looser NaN handling.
    fn min_simd(self, rhs: Self) -> Self;

    /// Return the pairwise minimum of `self` and `rhs` as if by applying the standard
    /// library's `min` method for the scalar type.
    #[inline(always)]
    fn min_simd_standard(self, rhs: Self) -> Self {
        self.min_simd(rhs)
    }

    /// Return the pairwise maximum of `self` and `rhs`, subject to looser NaN handling.
    fn max_simd(self, rhs: Self) -> Self;

    /// Return the pairwise maximum of `self` and `rhs` as if by applying the standard
    /// library's `max` method for the scalar type.
    #[inline(always)]
    fn max_simd_standard(self, rhs: Self) -> Self {
        self.max_simd(rhs)
    }
}

/// Take the absolute value of each lane.
///
/// # Notes
///
/// For signed integer types `T`, this works as expected for all values except for `T::MIN`,
/// in which case `T::MIN` is returned. This keeps the behavior in line with hardware
/// intrinsics.
///
/// A correct answer can be retrieved by casting the result to the equivalent unsigned
/// integer.
pub trait SIMDAbs {
    fn abs_simd(self) -> Self;
}

/// A SIMD equivalent of `std::cmp::PartialEq`.
///
/// Instead of a boolean, return `Self::Mask` containing the result of the element-wise
/// comparison of the two vectors.
pub trait SIMDPartialEq: SIMDVector {
    /// SIMD equivalent of `std::cmp::PartialEq::eq`, applying the latter trait to each
    /// lane-wise pair of elements in `self` and `other`.
    fn eq_simd(self, other: Self) -> Self::Mask;

    /// SIMD equivalent of `std::cmp::PartialEq::ne`, applying the latter trait to each
    /// lane-wise pair of elements in `self` and `other`.
    fn ne_simd(self, other: Self) -> Self::Mask;
}

/// A SIMD equivalent of `std::cmp::PartialOrd`.
///
/// Instead of a boolean, return `Self::Mask` containing the result of the element-wise
/// comparisons of the two vectors.
pub trait SIMDPartialOrd: SIMDVector {
    /// SIMD equivalent of `std::cmp::PartialOrd::lt`.
    fn lt_simd(self, other: Self) -> Self::Mask;

    /// SIMD equivalent of `std::cmp::PartialOrd::le`.
    fn le_simd(self, other: Self) -> Self::Mask;

    //////////////////////
    // Provided Methods //
    //////////////////////

    /// SIMD equivalent of `std::cmp::PartialOrd::gt`.
    ///
    /// Types are free to override the provided method if a more efficient implementation
    /// is possible.
    #[inline(always)]
    fn gt_simd(self, other: Self) -> Self::Mask {
        other.lt_simd(self)
    }

    /// SIMD equivalent of `std::cmp::PartialOrd::ge`.
    ///
    /// Types are free to override the provided method if a more efficient implementation
    /// is possible.
    #[inline(always)]
    fn ge_simd(self, other: Self) -> Self::Mask {
        other.le_simd(self)
    }
}

/// Perform a pairwise reducing sum of all lanes in the vector and return the result as a
/// scalar.
///
/// For example, the summing pattern for a vector of 8 elements is as follows:
/// ```text
/// let v0 = [x0, x1, x2, x3, x4, x5, x6, x7];
/// let v1 = [v0[0] + v0[4], v0[1] + v0[5], v0[2] + v0[6], v0[3] + v0[7]];
/// let v2 = [v1[0] + v1[2], v1[1] + v1[3]];
/// v2[0] + v2[1]
/// ```
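///
/// A common use is the final reduction of a dot product, sketched here assuming a vector
/// type `V` implementing both [`SIMDMulAdd`] and this trait:
/// ```ignore
/// let mut acc = V::default(arch);
/// for (x, y) in /* pairs of loaded vectors */ {
///     acc = x.mul_add_simd(y, acc);
/// }
/// let dot = acc.sum_tree();
/// ```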
pub trait SIMDSumTree: SIMDVector {
    fn sum_tree(self) -> <Self as SIMDVector>::Scalar;
}

/// A vectorized "if else".
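///
/// A per-lane sketch of the assumed semantics (that set lanes select from `x` is an
/// assumption of this example):
/// ```ignore
/// // out[i] == if mask.get_unchecked(i) { x[i] } else { y[i] }
/// let out = mask.select(x, y);
/// ```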
pub trait SIMDSelect<V: SIMDVector>: SIMDMask {
    fn select(self, x: V, y: V) -> V;
}

/// Optimized dot-product style accumulation.
///
/// This tries to match against intrinsics like:
///
/// * `_mm256_madd_epi16`
/// * `_mm256_dpbusd_epi32`
///
/// The gist is to perform element-wise multiplication between left and right, promoting the
/// results to the element type of `Self`, summing adjacent groups of the promoted products,
/// and adding those sums into `self`.
///
/// # Precise Enumeration of Semantics for Implementations
///
/// The semantics depend on the source and destination type, but are intended to be the same
/// for each type combination across architectures.
///
/// ## `SIMDDotProduct<i16x16> for i32x8`
///
/// 1. Perform multiplication as `i16x16 x i16x16` as if converting each lane to `i32`,
///    resulting in effectively `i32x16`. No overflow can happen.
/// 2. Add together adjacent pairs in the resulting `i32x16` to yield `i32x8`. Again, this
///    step cannot overflow.
/// 3. Add the resulting `i32x8` into `Self`, returning the result.
///
/// ## `SIMDDotProduct<u8x32, i8x32> for i32x8`
///
/// 1. Perform multiplication as `u8x32 x i8x32` as if converting each lane to `i32`,
///    resulting in effectively `i32x32`. No overflow can happen.
/// 2. Sum together consecutive groups of 4 in the resulting `i32x32` to yield `i32x8`.
/// 3. Add the resulting `i32x8` into `Self`.
///
/// The same applies when the order of `u8x32` and `i8x32` are swapped and for types that
/// are twice as wide.
///
/// The main goal of this function is to hit VNNI instructions like `_mm512_dpbusd_epi32`
/// that can do the whole operation in a single go on the `V4` architecture. Use of this
/// instruction is not recommended on non-`V4` architectures.
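///
/// A scalar sketch of the `i16x16` case described above (the wrapping accumulation in the
/// final step is an assumption of this sketch):
/// ```ignore
/// for i in 0..8 {
///     let lo = (left[2 * i] as i32) * (right[2 * i] as i32);
///     let hi = (left[2 * i + 1] as i32) * (right[2 * i + 1] as i32);
///     out[i] = acc[i].wrapping_add(lo + hi);
/// }
/// ```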
pub trait SIMDDotProduct<L: SIMDVector, R: SIMDVector = L> {
    /// Element-wise multiply each component of `left` and `right`, promoting the
    /// intermediate results to a higher precision.
    ///
    /// Then, horizontally add together groups of the accumulated values and add the
    /// resulting sums to `self`.
    ///
    /// The size of the group depends on the relative number of lanes in `Self` and the
    /// source vectors.
    ///
    /// However, it is required that `Self::num_lanes()` evenly divides the number of
    /// source lanes so that the size of each group is uniform.
    fn dot_simd(self, left: L, right: R) -> Self;
}

/// Perform a bit-cast from one SIMD type to another.
pub trait SIMDReinterpret<To: SIMDVector>: SIMDVector {
    fn reinterpret_simd(self) -> To;
}

/// Perform a numeric cast on the scalar type.
///
/// Unlike `From`, this conversion is allowed to be lossy, with similar semantics to
/// numeric casts in scalar Rust.
///
/// This is meant to model Rust's numeric conversion with the `as` operator.
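///
/// A minimal sketch, assuming an 8-lane `f32` vector type implementing `SIMDCast<i32>`:
/// ```ignore
/// // Per lane: out[i] == v[i] as i32.
/// let out = v.cast::<i32>();
/// ```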
pub trait SIMDCast<T>: SIMDVector {
    /// The [`SIMDVector`] type of the result.
    type Cast: SIMDVector<Scalar = T, ConstLanes = Self::ConstLanes>;
    /// Perform the cast.
    fn simd_cast(self) -> Self::Cast;
}

/// A roll-up of traits required for SIMD floating point types.
pub trait SIMDFloat:
    SIMDVector
    + std::ops::Add<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::Sub<Output = Self>
    + SIMDMulAdd
    + SIMDMinMax
    + SIMDPartialEq
    + SIMDPartialOrd
{
}

impl<T> SIMDFloat for T where
    T: SIMDVector
        + std::ops::Add<Output = Self>
        + std::ops::Mul<Output = Self>
        + std::ops::Sub<Output = Self>
        + SIMDMulAdd
        + SIMDMinMax
        + SIMDPartialEq
        + SIMDPartialOrd
{
}

/// A roll-up of traits required for SIMD unsigned integer types.
pub trait SIMDUnsigned:
    SIMDVector
    + std::ops::Add<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::Sub<Output = Self>
    + std::ops::BitAnd<Output = Self>
    + std::ops::BitOr<Output = Self>
    + std::ops::BitXor<Output = Self>
    + std::ops::Shr<Output = Self>
    + std::ops::Shl<Output = Self>
    + std::ops::Shr<Self::Scalar, Output = Self>
    + std::ops::Shl<Self::Scalar, Output = Self>
    + SIMDMulAdd
    + SIMDPartialEq
    + SIMDPartialOrd
{
}

impl<T> SIMDUnsigned for T where
    T: SIMDVector
        + std::ops::Add<Output = Self>
        + std::ops::Mul<Output = Self>
        + std::ops::Sub<Output = Self>
        + std::ops::BitAnd<Output = Self>
        + std::ops::BitOr<Output = Self>
        + std::ops::BitXor<Output = Self>
        + std::ops::Shr<Output = Self>
        + std::ops::Shl<Output = Self>
        + std::ops::Shr<Self::Scalar, Output = Self>
        + std::ops::Shl<Self::Scalar, Output = Self>
        + SIMDMulAdd
        + SIMDPartialEq
        + SIMDPartialOrd
{
}

/// A roll-up of traits required for SIMD signed integer types.
pub trait SIMDSigned: SIMDUnsigned + SIMDAbs {}
impl<T> SIMDSigned for T where T: SIMDUnsigned + SIMDAbs {}

// Since it is so difficult to work directly with generic integers, resort to using a macro
// to stamp out implementations of `SIMDMask` for `BitMask`.
//
// The argument `$submask` is a bit pattern applied to the underlying type, used to mask
// out the upper bits of the representation for 2- and 4-bit masks.
macro_rules! impl_simd_mask_for_bitmask {
    ($N:literal, $repr:ty, $submask:expr) => {
        impl<A: arch::Sealed> SIMDMask for BitMask<$N, A> {
            type Arch = A;
            type Underlying = $repr;
            type BitMask = Self;
            const ISBITS: bool = true;
            const LANES: usize = $N;

            #[inline(always)]
            fn arch(self) -> A {
                self.get_arch()
            }

            #[inline(always)]
            fn to_underlying(self) -> Self::Underlying {
                self.0
            }

            #[inline(always)]
            fn from_underlying(arch: A, value: Self::Underlying) -> Self {
                Self::from_int(arch, value)
            }

            #[inline(always)]
            fn keep_first(arch: A, i: usize) -> Self {
                // Ensure that providing a value that is too big still yields sensible
                // results.
                let i = i.min(Self::LANES);

                // Handle 64-bit integers properly.
                // It is expected that the compiler will be able to optimize out this branch
                // for non-64-bit types.
                if Self::LANES == 64 && i == 64 {
                    return Self::from_underlying(arch, Self::Underlying::MAX);
                }

                let one: u64 = 1;
                // `as` conversion in Rust truncates integers.
                Self::from_underlying(arch, ((one << i) - one) as Self::Underlying)
            }

            #[inline(always)]
            fn get_unchecked(&self, i: usize) -> bool {
                if i >= Self::LANES {
                    false
                } else {
                    (self.0 >> i) % 2 == 1
                }
            }

            #[inline(always)]
            fn first(&self) -> Option<usize> {
                let count = self.0.trailing_zeros() as usize;
                if count >= Self::LANES {
                    None
                } else {
                    Some(count)
                }
            }

            // End of recursion functions.
            fn from_fn<F>(arch: A, mut f: F) -> Self
            where
                F: FnMut(usize) -> bool,
            {
                let mut x: $repr = 0;
                for i in 0..Self::LANES {
                    if f(i) {
                        x |= 1 << i;
                    }
                }
                Self::from_underlying(arch, x)
            }

            #[inline(always)]
            fn any(self) -> bool {
                self.0 != 0
            }

            #[inline(always)]
            fn all(self) -> bool {
                let v: u64 = self.0.into();

                // We again need to handle 64-wide masks differently.
                if $N == 64 {
                    v == u64::MAX
                } else {
                    v == (1 << $N) - 1
                }
            }

            #[inline(always)]
            fn count(self) -> usize {
                // We keep the invariant that all constructors of `BitMask` must zero out
                // upper bits (for 2 and 4 width BitMasks).
                self.0.count_ones() as usize
            }
        }

        // Transformation from bitmask to integer.
        impl From<BitMask<$N>> for $repr {
            fn from(value: BitMask<$N>) -> Self {
                value.to_underlying()
            }
        }
    };
}

// Stamp out a bunch of implementations.
impl_simd_mask_for_bitmask!(1, u8, 0x1);
impl_simd_mask_for_bitmask!(2, u8, 0x3);
impl_simd_mask_for_bitmask!(4, u8, 0xf);
impl_simd_mask_for_bitmask!(8, u8, u8::MAX);
impl_simd_mask_for_bitmask!(16, u16, u16::MAX);
impl_simd_mask_for_bitmask!(32, u32, u32::MAX);
impl_simd_mask_for_bitmask!(64, u64, u64::MAX);

#[cfg(test)]
mod test_traits {
    use rand::{
        SeedableRng,
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{
        ARCH, arch,
        splitjoin::{LoHi, SplitJoin},
        test_utils,
    };

    // Allow unsigned 128-bit integers to be converted to narrower types.
    trait FromU128 {
        fn from_(value: u128) -> Self;
    }

    impl FromU128 for u8 {
        fn from_(value: u128) -> Self {
            value as u8
        }
    }
    impl FromU128 for u16 {
        fn from_(value: u128) -> Self {
            value as u16
        }
    }
    impl FromU128 for u32 {
        fn from_(value: u128) -> Self {
            value as u32
        }
    }
    impl FromU128 for u64 {
        fn from_(value: u128) -> Self {
            value as u64
        }
    }

    /// Test that bitmasks faithfully implement the trait `SIMDMask`.
    ///
    /// Since conversions between `u128` and an arbitrary generic parameter `T` are not
    /// allowed, we take a conversion function to do this for us with all the known type
    /// information.
    fn test_bitmask_impl<const N: usize, T>()
    where
        Const<N>: SupportedLaneCount, // This value of `N` has a bitmask representation.
        T: std::fmt::Debug + std::cmp::Eq + FromU128 + From<BitMask<N, arch::Current>>,
        BitMask<N, arch::Current>: SIMDMask<Arch = arch::Current, Underlying = T>,
    {
        const MAXLEN: usize = 64;
        assert_eq!(N, BitMask::<N, arch::Current>::LANES);

        // The bit-mask corresponding to all lanes.
        let one = 1_u128;

        let all: u128 = (one << N) - one;

        for i in 0..=MAXLEN {
            let mask = BitMask::<N, arch::Current>::keep_first(arch::current(), i);

            let expected: u128 = ((one << i) - one) & all;

            // Cannot use "as" since T is not known to be a primitive type ...
            assert_eq!(mask.to_underlying(), T::from_(expected));
            assert_eq!(T::from_(expected), mask.into());
            for j in 0..=MAXLEN {
                let b = mask.get_unchecked(j);
                let o = mask.get(j);

                let expected: bool = j < i;
                if j < N {
                    assert_eq!(b, expected);
                    assert_eq!(o.unwrap(), expected);
                } else {
                    assert!(!b);
                    assert!(o.is_none());
                }
            }

            // Check reductions.
            if i == 0 {
                assert!(!mask.any());
                assert!(!mask.all());
                assert!(mask.none());
            } else if i >= N {
                assert!(mask.any());
                assert!(mask.all());
                assert!(!mask.none());
            } else {
                assert!(mask.any());
                assert!(!mask.all());
                assert!(!mask.none());
            }
        }
    }

    #[test]
    fn test_bitmask() {
        test_bitmask_impl::<1, u8>();
        test_bitmask_impl::<2, u8>();
        test_bitmask_impl::<4, u8>();
        test_bitmask_impl::<8, u8>();
        test_bitmask_impl::<16, u16>();
        test_bitmask_impl::<32, u32>();
        test_bitmask_impl::<64, u64>();
    }

    fn test_bitmask_splitjoin_impl<const N: usize, const NHALF: usize>(ntrials: usize, seed: u64)
    where
        Const<N>: SupportedLaneCount,
        Const<NHALF>: SupportedLaneCount,
        BitMask<N, arch::Current>:
            SIMDMask<Arch = arch::Current> + SplitJoin<Halved = BitMask<NHALF, arch::Current>>,
        BitMask<NHALF, arch::Current>: SIMDMask<Arch = arch::Current>,
    {
        let mut rng = StdRng::seed_from_u64(seed);
        for _ in 0..ntrials {
            let base = BitMask::<N>::from_fn(ARCH, |_| StandardUniform {}.sample(&mut rng));
            let LoHi { lo, hi } = base.split();

            for i in 0..NHALF {
                assert_eq!(base.get(i).unwrap(), lo.get(i).unwrap());
            }

            for i in 0..NHALF {
                assert_eq!(base.get(i + NHALF).unwrap(), hi.get(i).unwrap());
            }

            let joined = BitMask::<N>::join(LoHi::new(lo, hi));
            bitmasks_equal(base, joined);
        }
    }

    #[test]
    fn test_bitmask_splitjoin() {
        test_bitmask_splitjoin_impl::<2, 1>(100, 0xcbdbdca310caec88);
        test_bitmask_splitjoin_impl::<4, 2>(100, 0x9c8b9b6c70d941c5);
        test_bitmask_splitjoin_impl::<8, 4>(100, 0xc81a25918b683d39);
        test_bitmask_splitjoin_impl::<16, 8>(50, 0xad045b437c3fa0cc);
        test_bitmask_splitjoin_impl::<32, 16>(50, 0xe710ccdbbd329c77);
        test_bitmask_splitjoin_impl::<64, 32>(25, 0xd6697e3c534fc134);
    }

    // Explicit tests to ensure that upper bits are masked out during construction of
    // 2- and 4-bit masks.
    #[test]
    fn test_zeroing() {
        let b = BitMask::<2>::from_underlying(arch::current(), 0xff);
        assert_eq!(b.to_underlying(), 0x3);
        assert_eq!(b.count(), 2);

        let b = BitMask::<4>::from_underlying(arch::current(), 0xff);
        assert_eq!(b.to_underlying(), 0xf);
        assert_eq!(b.count(), 4);
    }

    fn bitmasks_equal<const N: usize>(x: BitMask<N, arch::Current>, y: BitMask<N, arch::Current>)
    where
        Const<N>: SupportedLaneCount,
        BitMask<N, arch::Current>: SIMDMask,
    {
        assert_eq!(x.0, y.0);
    }

    // A helper macro to run a BitMask through the SIMDMask test routines.
    macro_rules! test_simdmask {
        ($N:literal) => {
            paste::paste! {
                #[test]
                fn [<test_simd_mask_ $N>]() {
                    let arch = arch::current();
                    test_utils::mask::test_keep_first::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal
                    );
                    test_utils::mask::test_from_fn::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal
                    );
                    test_utils::mask::test_reductions::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal
                    );
                    test_utils::mask::test_first::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal
                    );
                }
            }
        };
    }

    test_simdmask!(2);
    test_simdmask!(4);
    test_simdmask!(8);
    test_simdmask!(16);
    test_simdmask!(32);
    test_simdmask!(64);
}