// diskann_wide/traits.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */
5
6use super::{
7    arch,
8    bitmask::{BitMask, FromInt},
9    constant::{Const, SupportedLaneCount},
10};
11
/// Rust currently lacks the ability to use a trait's associated constants as constraints
/// or constant parameters to other type definitions within the same trait.
///
/// The use of the `Const` type moves the const generic into the type domain, allowing us
/// to properly constrain and define other associated items.
///
/// In particular, we currently can't do nice things like:
/// ```ignore
/// trait MyTrait {
///     const MY_CONSTANT: usize;
///     type ArrayType = [f32; Self::MY_CONSTANT];
///                            ----------------- Currently unsupported.
/// }
/// ```
///
/// See:
/// - <https://users.rust-lang.org/t/limitation-of-associated-const-in-traits/73491/2>
/// - <https://github.com/rust-lang/rust/issues/76560>
pub trait ArrayType<T>: SupportedLaneCount {
    /// The array type associated with this lane count; for `Const<N>` this is `[T; N]`.
    type Type;
}
33
/// Map scalar + lengths to arrays.
impl<T, const N: usize> ArrayType<T> for Const<N>
where
    Const<N>: SupportedLaneCount,
{
    // `Const<N>` carries the lane count in the type domain, so the array type is
    // directly expressible as `[T; N]`.
    type Type = [T; N];
}
41
/// Stable Rust does not allow expressions involving compile-time computation with
/// const generic parameters: <https://github.com/rust-lang/rust/issues/76560>
///
/// This makes it difficult to go from a const parameter defining the number of SIMD lanes
/// to an appropriately sized mask.
///
/// This helper trait provides a level of indirection to map SIMD representations to
/// the associated bitmask.
pub trait BitMaskType<A: arch::Sealed>: SupportedLaneCount {
    /// The mask type for this lane count. Should always be a `BitMask`.
    type Type;
}
53
/// Map an architecture plus a lane count to the corresponding `BitMask`.
impl<A, const N: usize> BitMaskType<A> for Const<N>
where
    Const<N>: SupportedLaneCount,
    A: arch::Sealed,
{
    type Type = BitMask<N, A>;
}
61
/// Convert `Self` to the SIMD type `T`. This is mainly useful when implementing fallback
/// operations through [`crate::Emulated`] to restore the original SIMD type.
pub trait AsSIMD<T>: Copy
where
    T: SIMDVector,
{
    /// Perform the conversion. The `arch` argument acts as a proof that the target
    /// architecture's requirements are met.
    fn as_simd(self, arch: T::Arch) -> T;
}
70
71/// A logical mask for SIMD operations.
72///
73/// The representation of this type varies between architectures and micro-architectures.
74/// For example:
75///
76/// * On AVX 2 systems, a SIMD mask for type/length pairs `(T, N)` consists of a SIMD
77///   register of an unsigned integer with the same size of `T` and length `N`.
78///
79///   The semantics of such registers are to allow operations in lanes where the top-most
80///   bit is set to 1.
81///
82/// * On AVX-512 systems, the story is much simpler as the masks used in that instruction
83///   set are simply the corresponding bit mask.
84///
85///   So a mask for 8-wide operations is simply an 8-bit unsigned integer.
86///
87/// * Emulated systems should use a bit-mask for the most compact representation.
88pub trait SIMDMask: Copy + std::fmt::Debug {
89    /// The architecture type this struct belongs to.
90    type Arch: arch::Sealed;
91
92    /// The type of the underlying intrinsic.
93    type Underlying: Copy + std::fmt::Debug;
94
95    /// The bitmask associated with the logical mask.
96    type BitMask: SIMDMask<Arch = Self::Arch> + Into<Self> + From<Self>;
97
98    /// The number of lanes in the bitmask.
99    const LANES: usize;
100
101    /// Whether or not this mask implementation is a bit mask.
102    const ISBITS: bool;
103
104    /// Return the architecture object associated with this vector.
105    fn arch(self) -> Self::Arch;
106
107    /// Retrieve the underlying type.
108    /// This will always be an unsigned integer of the minimum width required to contain
109    /// `LANES` bits.
110    fn to_underlying(self) -> Self::Underlying;
111
112    /// Construct the mask from the underlying type.
113    fn from_underlying(arch: Self::Arch, value: Self::Underlying) -> Self;
114
115    /// Return `true` if lane `i` is set and `false` otherwise.
116    ///
117    /// This method is unchecked, but safe in the sense that if `i >= LANES` false will
118    /// always be returned. No out of bounds access will be made, but no error indication
119    /// will be provided.
120    fn get_unchecked(&self, i: usize) -> bool;
121
122    /// Efficiently construct a new mask with the first `i` bits set and the remainder
123    /// set to zero.
124    ///
125    /// If `i >= LANES` then all bits will be set.
126    fn keep_first(arch: Self::Arch, i: usize) -> Self;
127
128    /// Return the first set index in the mask or `None` if no entries are set.
129    fn first(&self) -> Option<usize> {
130        self.bitmask().first()
131    }
132
133    //////////////////////////////
134    // Provided Implementations //
135    //////////////////////////////
136
137    /// Return the associated BitMask for this Mask.
138    fn bitmask(self) -> Self::BitMask {
139        <Self::BitMask as From<Self>>::from(self)
140    }
141
142    /// Return `true` if lane `i` is set and `false` otherwise. Returns an empty `Option`
143    /// if the index `i` is out-of-bounds.
144    fn get(&self, i: usize) -> Option<bool> {
145        if i >= Self::LANES {
146            None
147        } else {
148            Some(self.get_unchecked(i))
149        }
150    }
151
152    /// Construct a mask based on the result of invoking `f` once for each element in the range
153    /// `0..Self::LANES` in order.
154    ///
155    /// In the returned mask `m`, `m.get(0)` corresponds to the value of `f(0)`. Similarly,
156    /// `m.get(1)` corresponds to `f(1)` etc.
157    #[inline(always)]
158    fn from_fn<F>(arch: Self::Arch, f: F) -> Self
159    where
160        F: FnMut(usize) -> bool,
161    {
162        // Recurse to BitMask.
163        Self::BitMask::from_fn(arch, f).into()
164    }
165
166    /// Return `true` if any lane in the mask is set. Otherwise, return `false`.
167    #[inline(always)]
168    fn any(self) -> bool {
169        // Recurse to BitMask.
170        <Self::BitMask as From<Self>>::from(self).any()
171    }
172
173    /// Return `true` if all lanes in the mask are set. Otherwise, return `false`.
174    #[inline(always)]
175    fn all(self) -> bool {
176        // Recurse to BitMask.
177        <Self::BitMask as From<Self>>::from(self).all()
178    }
179
180    /// Return `true` if no lanes in the mask are set. Otherwise, return `false`.
181    #[inline(always)]
182    fn none(self) -> bool {
183        !self.any()
184    }
185
186    /// Return the number of lanes that evaluate to `true`.
187    #[inline(always)]
188    fn count(self) -> usize {
189        // Recurse to BitMask.
190        <Self::BitMask as From<Self>>::from(self).count()
191    }
192}
193
/// A trait representing minimal behavior for a SIMD-like vector.
///
/// A SIMDVector can be thought of as a homogeneous array `[T; N]` (with potentially
/// stricter alignment requirements) that generally behaves for arithmetic purposes like
/// a scalar, in the sense that
/// ```ignore
/// fn add(a: V, b: V) -> V
/// where V: SIMDVector {
///     a + b
/// }
/// ```
/// has the same semantics as broadcasting the `+` operation across all lanes in the
/// vector.
pub trait SIMDVector: Copy + std::fmt::Debug {
    /// The architecture this vector belongs to.
    type Arch: arch::Sealed;

    /// The type of each element in the vector.
    type Scalar: Copy + std::fmt::Debug;

    /// The underlying representation.
    type Underlying: Copy;

    /// The number of lanes in the vector.
    const LANES: usize;

    /// The value of `LANES` but in the type domain so we can use it to constrain other
    /// aspects of this trait.
    ///
    /// Should be the type `Const<Self::LANES>`.
    type ConstLanes: ArrayType<Self::Scalar> + BitMaskType<Self::Arch>;

    /// The expanded logical mask representation.
    /// This may-or-may-not actually be a bitmask, but should be easily convertible to and
    /// from a bitmask (the `From`/`Into` bounds below enforce this).
    type Mask: SIMDMask<Arch = Self::Arch>
        + From<<Self::ConstLanes as BitMaskType<Self::Arch>>::Type>
        + Into<<Self::ConstLanes as BitMaskType<Self::Arch>>::Type>;

    /// Whether or not this is an emulated vector.
    ///
    /// Emulated vectors are backed by Rust arrays and use scalar loops to implement
    /// arithmetic operations.
    const EMULATED: bool;

    /// Return the architecture object associated with this vector.
    ///
    /// # NOTE
    ///
    /// This is safe because construction of `self` serves as the witness that we are on
    /// a compatible architecture.
    fn arch(self) -> Self::Arch;

    /// Return the default value for the type. This is always the numeric 0 for the
    /// associated scalar type.
    fn default(arch: Self::Arch) -> Self;

    /// Return the underlying type.
    fn to_underlying(self) -> Self::Underlying;

    /// Construct from the underlying type.
    fn from_underlying(arch: Self::Arch, repr: Self::Underlying) -> Self;

    /// Retrieve the contents as an array.
    fn to_array(self) -> <Self::ConstLanes as ArrayType<Self::Scalar>>::Type;

    /// Construct from the associated array.
    ///
    /// The argument `arch` provides a "proof of compatibility" as `A` can only be safely
    /// instantiated when all the requirements for the architecture are met.
    fn from_array(arch: Self::Arch, x: <Self::ConstLanes as ArrayType<Self::Scalar>>::Type)
    -> Self;

    /// Broadcast the provided scalar across all lanes.
    ///
    /// The argument `arch` provides a "proof of compatibility" as `A` can only be safely
    /// instantiated when all the requirements for the architecture are met.
    fn splat(arch: Self::Arch, value: Self::Scalar) -> Self;

    /// Return the number of lanes in this vector (the value of `Self::LANES`).
    fn num_lanes() -> usize {
        Self::LANES
    }

    /// Load `<Self as SIMDVector>::LANES` number of elements starting at the provided
    /// pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// # Safety
    ///
    /// A contiguous read of `<Self as SIMDVector>::LANES` must touch valid memory.
    unsafe fn load_simd(arch: Self::Arch, ptr: *const <Self as SIMDVector>::Scalar) -> Self;

    /// Load `<Self as SIMDVector>::LANES` number of elements starting at the provided
    /// pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// Entries in the mask that evaluate to `false` will not be accessed.
    /// This makes it safe to use this function with lanes masked out that would otherwise
    /// cross a page boundary or otherwise cause an out-of-bounds read.
    ///
    /// # Safety
    ///
    /// Offsets from the `ptr` where the mask evaluates to true must be dereferenceable to
    /// the underlying scalar type.
    unsafe fn load_simd_masked_logical(
        arch: Self::Arch,
        ptr: *const <Self as SIMDVector>::Scalar,
        mask: <Self as SIMDVector>::Mask,
    ) -> Self;

    /// The same as `load_simd_masked_logical` but taking a BitMask instead.
    ///
    /// No load attempt will be made to lanes that are masked out.
    ///
    /// # Safety
    ///
    /// Offsets from the `ptr` where the mask evaluates to true must be dereferenceable to
    /// the underlying scalar type. For implementations using the provided default, the
    /// conversion from the bitmask to the actual mask must be correct.
    #[inline(always)]
    unsafe fn load_simd_masked(
        arch: Self::Arch,
        ptr: *const <Self as SIMDVector>::Scalar,
        mask: <<Self as SIMDVector>::ConstLanes as BitMaskType<Self::Arch>>::Type,
    ) -> Self {
        // SAFETY: Bitmasks must be convertible to their corresponding logical mask.
        // When the logical mask **is** a bitmask, this is a no-op.
        unsafe { Self::load_simd_masked_logical(arch, ptr, mask.into()) }
    }

    /// The same as `load_simd_masked_logical`, but potentially specialized for situations
    /// where it is known that some number of first elements will be accessed.
    ///
    /// If `first` is greater than or equal to the number of lanes, then all lanes will be
    /// loaded.
    ///
    /// # Safety
    ///
    /// A contiguous read of `first.min(<Self as SIMDVector>::LANES)` must be valid.
    #[inline(always)]
    unsafe fn load_simd_first(
        arch: Self::Arch,
        ptr: *const <Self as SIMDVector>::Scalar,
        first: usize,
    ) -> Self {
        // SAFETY: The implementation of `SIMDMask` must be correct.
        // `keep_first` builds a mask covering exactly the first lanes, so the masked
        // load only touches memory the caller has promised is valid.
        unsafe {
            Self::load_simd_masked_logical(
                arch,
                ptr,
                <Self as SIMDVector>::Mask::keep_first(arch, first),
            )
        }
    }

    /// Store `<Self as SIMDVector>::LANES` number of elements contiguously starting at the
    /// provided pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// A contiguous store of `<Self as SIMDVector>::LANES` must touch valid memory.
    unsafe fn store_simd(self, ptr: *mut <Self as SIMDVector>::Scalar);

    /// Store `<Self as SIMDVector>::LANES` number of elements starting at the provided
    /// pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// Entries in the mask that evaluate to `false` will not be accessed.
    /// This makes it safe to use this function with lanes masked out that would otherwise
    /// cross a page boundary or otherwise cause an out-of-bounds write.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// Offsets from the `ptr` where the mask evaluates to true must be mutably
    /// dereferenceable to the underlying scalar type.
    unsafe fn store_simd_masked_logical(
        self,
        ptr: *mut <Self as SIMDVector>::Scalar,
        mask: <Self as SIMDVector>::Mask,
    );

    /// The same as `store_simd_masked_logical` but taking a BitMask instead.
    ///
    /// No store attempt will be made to lanes that are masked out.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// Offsets from the `ptr` where the mask evaluates to true must be mutably
    /// dereferenceable to the underlying scalar type.
    ///
    /// For implementations using the provided default, the conversion from the bitmask to
    /// the actual mask must be correct.
    #[inline(always)]
    unsafe fn store_simd_masked(
        self,
        ptr: *mut <Self as SIMDVector>::Scalar,
        mask: <<Self as SIMDVector>::ConstLanes as BitMaskType<Self::Arch>>::Type,
    ) {
        // SAFETY: Bitmasks must be convertible to their corresponding logical mask.
        // When the logical mask **is** a bitmask, this is a no-op.
        unsafe { self.store_simd_masked_logical(ptr, mask.into()) }
    }

    /// The same as `store_simd_masked_logical`, but potentially specialized for situations
    /// where it is known that some number of first elements will be accessed.
    ///
    /// If `first` is greater than or equal to the number of lanes, then all lanes will be
    /// written.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// A contiguous write of `first.min(<Self as SIMDVector>::LANES)` must be valid.
    #[inline(always)]
    unsafe fn store_simd_first(self, ptr: *mut <Self as SIMDVector>::Scalar, first: usize) {
        // SAFETY: The implementation of `SIMDMask` must be correct.
        // `keep_first` restricts the store to the first lanes only.
        unsafe {
            self.store_simd_masked_logical(
                ptr,
                <Self as SIMDVector>::Mask::keep_first(self.arch(), first),
            )
        }
    }

    /// Perform a numeric cast on each element, returning a new SIMD vector.
    ///
    /// See also: [`SIMDCast`].
    #[inline(always)]
    fn cast<T>(self) -> <Self as SIMDCast<T>>::Cast
    where
        Self: SIMDCast<T>,
    {
        self.simd_cast()
    }
}
446
/// Efficiently perform the operation
/// ```ignore
/// self * rhs + accumulator
/// ```
/// with the following semantics dependent on the associated scalar type.
///
/// * floating point: Perform a fused multiply-add, implementing the operation with only a
///   single rounding instance.
///
/// * integer: Perform the multiplication followed by the accumulation. Both binary
///   operations will be performed using wrap-around arithmetic.
pub trait SIMDMulAdd {
    /// Return `self * rhs + accumulator` (fused for floats, wrapping for integers).
    fn mul_add_simd(self, rhs: Self, accumulator: Self) -> Self;
}
461
/// Efficiently retrieve the pairwise minimum or maximum for the two arguments.
///
/// Each function comes in two flavors:
///
/// * Standard (suffixed): Compute the minimum or maximum in a way that is equivalent to
///   Rust's built-in minimum or maximum functions.
///
///   When the scalar type is integral, the behavior is unambiguous.
///
///   When the scalar type is a floating point and one value of a pair is NaN, the other
///   value is returned. When the result is zero, either a positive or a negative zero can
///   be returned.
///
/// * Fast (unsuffixed): Compute the minimum or maximum using the fastest possible method
///   on the given architecture with non-standard NaN handling.
///
///   When the scalar type is integral, the behavior is the same as the standard
///   implementations.
///
///   When the scalar type is a floating point, the implementation is allowed to differ
///   with respect to NaN handling. That is, when one of the arguments is NaN, the
///   implementation is allowed to return **either** the other argument (like the standard
///   implementation) or NaN. Like the standard implementation, if the result is zero, then
///   a zero of either sign can be returned.
///
///   This method should be preferred when precise NaN handling is not needed as it can be
///   more efficient.
pub trait SIMDMinMax: Sized {
    /// Return the pairwise minimum of `self` and `rhs`, subject to looser NaN handling.
    fn min_simd(self, rhs: Self) -> Self;

    /// Return the pairwise minimum of `self` and `rhs` as if by applying the standard
    /// library's `min` method for the scalar type.
    ///
    /// NOTE(review): the provided default delegates to the fast variant; floating-point
    /// implementations whose `min_simd` deviates on NaN must override this method.
    #[inline(always)]
    fn min_simd_standard(self, rhs: Self) -> Self {
        self.min_simd(rhs)
    }

    /// Return the pairwise maximum of `self` and `rhs`, subject to looser NaN handling.
    fn max_simd(self, rhs: Self) -> Self;

    /// Return the pairwise maximum of `self` and `rhs` as if by applying the standard
    /// library's `max` method for the scalar type.
    ///
    /// NOTE(review): the provided default delegates to the fast variant; floating-point
    /// implementations whose `max_simd` deviates on NaN must override this method.
    #[inline(always)]
    fn max_simd_standard(self, rhs: Self) -> Self {
        self.max_simd(rhs)
    }
}
510
/// Take the absolute value of each lane.
///
/// # Notes
///
/// For signed integer types T, this works as expected for all values except for `T::MIN`,
/// in which case `T::MIN` is returned. This keeps the behavior in line with hardware
/// intrinsics.
///
/// A correct answer can be retrieved by casting the result to the equivalent unsigned
/// integer.
pub trait SIMDAbs {
    /// Return the lane-wise absolute value (see the trait-level notes on `T::MIN`).
    fn abs_simd(self) -> Self;
}
524
/// A SIMD equivalent of `std::cmp::PartialEq`.
///
/// Instead of a boolean, return `Self::Mask` containing the result of the element-wise
/// comparison of the two vectors.
pub trait SIMDPartialEq: SIMDVector {
    /// SIMD equivalent of `std::cmp::PartialEq::eq`, applying the latter trait to each
    /// lane-wise pair of elements in `self` and `other`.
    fn eq_simd(self, other: Self) -> Self::Mask;

    /// SIMD equivalent of `std::cmp::PartialEq::ne`, applying the latter trait to each
    /// lane-wise pair of elements in `self` and `other`.
    fn ne_simd(self, other: Self) -> Self::Mask;
}
538
/// A SIMD equivalent of `std::cmp::PartialOrd`.
///
/// Instead of a boolean, return `Self::Mask` containing the result of the element-wise
/// comparisons of the two vectors.
pub trait SIMDPartialOrd: SIMDVector {
    /// SIMD equivalent of `std::cmp::PartialOrd::lt`.
    fn lt_simd(self, other: Self) -> Self::Mask;

    /// SIMD equivalent of `std::cmp::PartialOrd::le`.
    fn le_simd(self, other: Self) -> Self::Mask;

    //////////////////////
    // Provided Methods //
    //////////////////////

    /// SIMD equivalent of `std::cmp::PartialOrd::gt`.
    ///
    /// The default derives `gt` by swapping the operands of `lt`.
    ///
    /// Types are free to override the provided method if a more efficient implementation
    /// is possible.
    #[inline(always)]
    fn gt_simd(self, other: Self) -> Self::Mask {
        other.lt_simd(self)
    }

    /// SIMD equivalent of `std::cmp::PartialOrd::ge`.
    ///
    /// The default derives `ge` by swapping the operands of `le`.
    ///
    /// Types are free to override the provided method if a more efficient implementation
    /// is possible.
    #[inline(always)]
    fn ge_simd(self, other: Self) -> Self::Mask {
        other.le_simd(self)
    }
}
572
/// Perform a pairwise reducing sum of all lanes in the vector and return the result as a
/// scalar.
///
/// For example, the summing pattern for a vector of 8 elements is as follows:
/// ```text
/// let v0 = [x0, x1, x2, x3, x4, x5, x6, x7];
/// let v1 = [v0[0] + v0[4], v0[1] + v0[5], v0[2] + v0[6], v0[3] + v0[7]];
/// let v2 = [v1[0] + v1[2], v1[1] + v1[3]];
/// v2[0] + v2[1]
/// ```
pub trait SIMDSumTree: SIMDVector {
    /// Reduce all lanes to a single scalar using the pairwise tree pattern above.
    fn sum_tree(self) -> <Self as SIMDVector>::Scalar;
}
586
/// A vectorized "if else".
pub trait SIMDSelect<V: SIMDVector>: SIMDMask {
    /// Lane-wise select between `x` and `y` based on this mask.
    ///
    /// NOTE(review): presumably set lanes take the value from `x` and cleared lanes
    /// from `y` — confirm against concrete implementations.
    fn select(self, x: V, y: V) -> V;
}
591
/// Optimized dot-product style accumulation.
///
/// This tries to match against intrinsics like:
///
/// * `_mm256_madd_epi16`
/// * `_mm256_dpbusd_epi32`
///
/// The gist is to perform element-wise multiplication between left and right, promoting the
/// result to the element-type of `Self`, adding adjacent entries
///
/// # Precise Enumeration of Semantics for Implementations
///
/// The semantics depend on the source and destination type, but are intended to be the same
/// for each type combination across architectures.
///
/// ## `SIMDDotProduct<i16x16> for i32x8`
///
/// 1. Perform multiplication as `i16x16 x i16x16` as if converting each lane to `i32`,
///    resulting in effectively `i32x16`. No overflow can happen.
/// 2. Add together adjacent pairs in the resulting `i32x16` to yield `i32x8`. Again, this
///    step cannot overflow.
/// 3. Add the resulting `i32x8` into `Self`, returning the result.
///
/// ## `SIMDDotProduct<u8x32, i8x32> for i32x8`
///
/// 1. Perform multiplication as `u8x32 x i8x32` as if converting each lane to `i32`,
///    resulting in effectively `i32x32`. No overflow can happen.
/// 2. Sum together consecutive groups of 4 in the resulting `i32x32` to yield `i32x8`.
/// 3. Add the resulting `i32x8` into `Self`.
///
/// The same applies when the order of `u8x32` and `i8x32` are swapped and for types that
/// are twice as wide.
///
/// The main goal of this function is to hit VNNI instructions like `_mm512_dpbusd_epi32`
/// that can do the whole operation in a single go on the `V4` architecture. Use of this
/// instruction is not recommended on non-`V4` architectures.
pub trait SIMDDotProduct<L: SIMDVector, R: SIMDVector = L> {
    /// Element wise multiply each component of `left` and `right`, promoting the
    /// intermediate results to a higher precision.
    ///
    /// Then, horizontally add together groups of the accumulated values and add the
    /// resulting sums to `self`.
    ///
    /// The size of the group depends on the relative number of lanes in `Self` and `Source`.
    ///
    /// However, it is required that `Self::num_lanes()` evenly divides `Source::num_lanes()`
    /// so that the size of each group is uniform.
    fn dot_simd(self, left: L, right: R) -> Self;
}
641
/// Perform a bit-cast from one SIMD type to another.
pub trait SIMDReinterpret<To: SIMDVector>: SIMDVector {
    /// Reinterpret the bits of `self` as `To` without modifying them.
    ///
    /// NOTE(review): presumably `Self` and `To` must occupy the same register width —
    /// confirm against concrete implementations.
    fn reinterpret_simd(self) -> To;
}
646
/// Perform a numeric cast on the scalar type.
///
/// Unlike `From`, this conversion is allowed to be lossy, with similar semantics to
/// numeric casts in scalar Rust.
///
/// This is meant to model Rust's numeric conversion with the "as" operator.
pub trait SIMDCast<T>: SIMDVector {
    /// The [`SIMDVector`] type of the result. The `ConstLanes` bound guarantees the
    /// cast preserves the lane count.
    type Cast: SIMDVector<Scalar = T, ConstLanes = Self::ConstLanes>;
    /// Perform the cast.
    fn simd_cast(self) -> Self::Cast;
}
659
/// A roll-up of traits required for SIMD floating point types.
pub trait SIMDFloat:
    SIMDVector
    + std::ops::Add<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::Sub<Output = Self>
    + SIMDMulAdd
    + SIMDMinMax
    + SIMDPartialEq
    + SIMDPartialOrd
{
}

// Blanket implementation: any vector providing the required arithmetic and comparison
// traits is automatically a `SIMDFloat`.
impl<T> SIMDFloat for T where
    T: SIMDVector
        + std::ops::Add<Output = Self>
        + std::ops::Mul<Output = Self>
        + std::ops::Sub<Output = Self>
        + SIMDMulAdd
        + SIMDMinMax
        + SIMDPartialEq
        + SIMDPartialOrd
{
}
684
/// A roll-up of traits required for SIMD integer types.
pub trait SIMDUnsigned:
    SIMDVector
    + std::ops::Add<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::Sub<Output = Self>
    + std::ops::BitAnd<Output = Self>
    + std::ops::BitOr<Output = Self>
    + std::ops::BitXor<Output = Self>
    + std::ops::Shr<Output = Self>
    + std::ops::Shl<Output = Self>
    + std::ops::Shr<Self::Scalar, Output = Self>
    + std::ops::Shl<Self::Scalar, Output = Self>
    + SIMDMulAdd
    + SIMDPartialEq
    + SIMDPartialOrd
{
}

// Blanket implementation: any vector providing the required arithmetic, bitwise, shift,
// and comparison traits is automatically a `SIMDUnsigned`.
impl<T> SIMDUnsigned for T where
    T: SIMDVector
        + std::ops::Add<Output = Self>
        + std::ops::Mul<Output = Self>
        + std::ops::Sub<Output = Self>
        + std::ops::BitAnd<Output = Self>
        + std::ops::BitOr<Output = Self>
        + std::ops::BitXor<Output = Self>
        + std::ops::Shr<Output = Self>
        + std::ops::Shl<Output = Self>
        + std::ops::Shr<Self::Scalar, Output = Self>
        + std::ops::Shl<Self::Scalar, Output = Self>
        + SIMDMulAdd
        + SIMDPartialEq
        + SIMDPartialOrd
{
}
721
/// A roll-up of traits required for SIMD signed integer types: everything required of
/// unsigned vectors plus lane-wise absolute value.
pub trait SIMDSigned: SIMDUnsigned + SIMDAbs {}
impl<T> SIMDSigned for T where T: SIMDUnsigned + SIMDAbs {}
724
725/// Element-wise zip and unzip of two half-width SIMD vectors.
726///
727/// `zip` interleaves elements from two halves into one full-width vector:
728///   `zip([a0, a1, …], [b0, b1, …]) = [a0, b0, a1, b1, …]`
729///
730/// `unzip` is the inverse, separating even- and odd-indexed elements:
731///   `unzip([a0, b0, a1, b1, …]) = ([a0, a1, …], [b0, b1, …])`
732///
733/// Two additional "flat" methods operate on a single full-width register:
734///
735/// `zip_flat` treats the low half of `self` as one input and the high half as the other,
736/// interleaving them in-place:
737///   `zip_flat([a0, a1, …, b0, b1, …]) = [a0, b0, a1, b1, …]`
738///
739/// `unzip_flat` is the inverse, collecting even-indexed elements into the low half and
740/// odd-indexed elements into the high half:
741///   `unzip_flat([a0, b0, a1, b1, …]) = [a0, a1, …, b0, b1, …]`
742///
743/// `zip` and `unzip` are required. `zip_flat` and `unzip_flat` have default
744/// implementations that delegate through [`SplitJoin::split`] / [`SplitJoin::join`].
745/// Backends should override whichever pair is cheapest on their ISA.
746pub trait ZipUnzip: crate::SplitJoin + Sized {
747    /// Interleave elements from `halves.lo` and `halves.hi` into `Self`.
748    fn zip(halves: crate::LoHi<Self::Halved>) -> Self;
749
750    /// Separate even-indexed elements into `lo` and odd-indexed into `hi`.
751    fn unzip(self) -> crate::LoHi<Self::Halved>;
752
753    /// Interleave in-place: the low half of `self` supplies the even-indexed
754    /// positions and the high half supplies the odd-indexed positions.
755    ///
756    /// Equivalent to `Self::zip(self.split())` but may be implemented with a
757    /// single cross-lane permute on architectures that support it.
758    fn zip_flat(self) -> Self {
759        Self::zip(self.split())
760    }
761
762    /// Deinterleave in-place: even-indexed elements are collected into the low
763    /// half of the result and odd-indexed elements into the high half.
764    ///
765    /// Equivalent to `Self::join(self.unzip())` but may be implemented with a
766    /// single cross-lane permute on architectures that support it.
767    fn unzip_flat(self) -> Self {
768        Self::unzip(self).join()
769    }
770}
771
772// Since it is so difficult to work directly with generic integers, resort to using a macro
773// to stamp out implementations of `SIMDMask` for `BitMask`.
774//
775// The argument `submask` is a bit-pattern to apply to the underlying type and is to mask
776// out upper-bits of the representation for 2 and 4 bit masks.
macro_rules! impl_simd_mask_for_bitmask {
    // NOTE(review): `$submask` is accepted but never expanded in the body below.
    // Upper-bit zeroing for the 2- and 4-lane masks is performed by the `BitMask`
    // constructors instead (`from_int`, exercised by `test_zeroing`).
    ($N:literal, $repr:ty, $submask:expr) => {
        impl<A: arch::Sealed> SIMDMask for BitMask<$N, A> {
            type Arch = A;
            type Underlying = $repr;
            // A bitmask is its own bitmask representation.
            type BitMask = Self;
            const ISBITS: bool = true;
            const LANES: usize = $N;

            #[inline(always)]
            fn arch(self) -> A {
                self.get_arch()
            }

            // Expose the raw integer representation.
            #[inline(always)]
            fn to_underlying(self) -> Self::Underlying {
                self.0
            }

            // Construct from a raw integer; `from_int` maintains the invariant
            // that bits at and above lane `$N` are zeroed (see `test_zeroing`).
            #[inline(always)]
            fn from_underlying(arch: A, value: Self::Underlying) -> Self {
                Self::from_int(arch, value)
            }

            // Build a mask with the first `i` lanes set (saturating at `LANES`).
            #[inline(always)]
            fn keep_first(arch: A, i: usize) -> Self {
                // Ensure that providing a value that is too big still yields sensible
                // results.
                let i = i.min(Self::LANES);

                // Handle 64-bit integers properly.
                // It is expected that the compiler will be able to optimize out this branch
                // for non-64-bit types.
                // (Without this early return, `one << 64` below would overflow.)
                if Self::LANES == 64 && i == 64 {
                    return Self::from_underlying(arch, Self::Underlying::MAX);
                }

                let one: u64 = 1;
                // "as" conversion in Rust performs truncation on the integers.
                Self::from_underlying(arch, ((one << i) - one) as Self::Underlying)
            }

            // Read lane `i`; out-of-range indices read as `false` (no panic,
            // despite the `_unchecked` name).
            #[inline(always)]
            fn get_unchecked(&self, i: usize) -> bool {
                if i >= Self::LANES {
                    false
                } else {
                    (self.0 >> i) % 2 == 1
                }
            }

            // Index of the lowest set lane, or `None` when the mask is empty.
            // For an all-zero value, `trailing_zeros` returns the bit-width of
            // `$repr`, which is `>= LANES` and therefore maps to `None`.
            #[inline(always)]
            fn first(&self) -> Option<usize> {
                let count = self.0.trailing_zeros() as usize;
                if count >= Self::LANES {
                    None
                } else {
                    Some(count)
                }
            }

            // End of recursion functions
            // (Presumably the base case terminating the trait's recursive default
            // impls; note it is not `#[inline(always)]` like the methods above —
            // confirm whether that is intentional.)
            fn from_fn<F>(arch: A, mut f: F) -> Self
            where
                F: FnMut(usize) -> bool,
            {
                // Only bits below `LANES` are ever set here, so the
                // zero-upper-bits invariant holds by construction.
                let mut x: $repr = 0;
                for i in 0..Self::LANES {
                    if f(i) {
                        x |= (1 << i);
                    }
                }
                Self::from_underlying(arch, x)
            }

            #[inline(always)]
            fn any(self) -> bool {
                self.0 != 0
            }

            #[inline(always)]
            fn all(self) -> bool {
                let v: u64 = self.0.into();

                // We again need to handle 64-wide masks differently.
                // (The else-branch `1 << 64` would overflow for `$N == 64`.)
                if $N == 64 {
                    v == u64::MAX
                } else {
                    v == (1 << $N) - 1
                }
            }

            #[inline(always)]
            fn count(self) -> usize {
                // We keep the invariant that all constructors of `BitMask` must zero out
                // upper bits (for 2 and 4 width BitMasks).
                self.0.count_ones() as usize
            }
        }

        // Transformation from bitmask to integer.
        // NOTE(review): implemented only for the default architecture parameter
        // of `BitMask<$N>`, not for every `A: arch::Sealed`.
        impl From<BitMask<$N>> for $repr {
            fn from(value: BitMask<$N>) -> Self {
                value.to_underlying()
            }
        }
    };
}
885
// Stamp out a bunch of implementations.
// Arguments: lane count, underlying integer representation, valid-lane bit pattern
// (the third argument is documentation-only; see the macro definition above).
impl_simd_mask_for_bitmask!(1, u8, 0x1);
impl_simd_mask_for_bitmask!(2, u8, 0x3);
impl_simd_mask_for_bitmask!(4, u8, 0xf);
impl_simd_mask_for_bitmask!(8, u8, u8::MAX);
impl_simd_mask_for_bitmask!(16, u16, u16::MAX);
impl_simd_mask_for_bitmask!(32, u32, u32::MAX);
impl_simd_mask_for_bitmask!(64, u64, u64::MAX);
894
#[cfg(test)]
mod test_traits {
    use rand::{
        SeedableRng,
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{
        ARCH, arch,
        splitjoin::{LoHi, SplitJoin},
        test_utils,
    };

    // Allow unsigned 128-bit integers to be converted to narrow types.
    // (Generic tests compute expected values in `u128` and narrow at the end;
    // `as` is unavailable for a generic `T`, hence this helper trait.)
    trait FromU128 {
        fn from_(value: u128) -> Self;
    }

    impl FromU128 for u8 {
        fn from_(value: u128) -> Self {
            value as u8
        }
    }
    impl FromU128 for u16 {
        fn from_(value: u128) -> Self {
            value as u16
        }
    }
    impl FromU128 for u32 {
        fn from_(value: u128) -> Self {
            value as u32
        }
    }
    impl FromU128 for u64 {
        fn from_(value: u128) -> Self {
            value as u64
        }
    }

    /// Test that bitmasks faithfully implement the trait `SIMDMask`.
    ///
    /// Since conversion between `u128` and arbitrary generic parameters `T` is not
    /// allowed, we take a conversion function to do this for us with all the known type
    /// information.
    fn test_bitmask_impl<const N: usize, T>()
    where
        Const<N>: SupportedLaneCount, // this value of `N` has a bitmask representation.
        T: std::fmt::Debug + std::cmp::Eq + FromU128 + From<BitMask<N, arch::Current>>,
        BitMask<N, arch::Current>: SIMDMask<Arch = arch::Current, Underlying = T>,
    {
        // Probe beyond `N` (up to the widest supported mask) to check that
        // out-of-range lanes and over-long prefixes behave sensibly.
        const MAXLEN: usize = 64;
        assert_eq!(N, BitMask::<N, arch::Current>::LANES);

        // The bit-mask corresponding to all lanes.
        let one = 1_u128;

        let all: u128 = (one << N) - one;

        for i in 0..=MAXLEN {
            let mask = BitMask::<N, arch::Current>::keep_first(arch::current(), i);

            let expected: u128 = ((one << i) - one) & all;

            // Cannot use "as" since T is not known to be a primitive type ...
            assert_eq!(mask.to_underlying(), T::from_(expected));
            assert_eq!(T::from_(expected), mask.into());
            for j in 0..=MAXLEN {
                let b = mask.get_unchecked(j);
                let o = mask.get(j);

                // Lanes below the kept prefix are set; in-range lanes must agree,
                // out-of-range lanes read `false` / `None`.
                let expected: bool = j < i;
                if j < N {
                    assert_eq!(b, expected);
                    assert_eq!(o.unwrap(), expected);
                } else {
                    assert!(!b);
                    assert!(o.is_none());
                }
            }

            // Check reductions.
            if i == 0 {
                assert!(!mask.any());
                assert!(!mask.all());
                assert!(mask.none());
            } else if i >= N {
                assert!(mask.any());
                assert!(mask.all());
                assert!(!mask.none());
            } else {
                assert!(mask.any());
                assert!(!mask.all());
                assert!(!mask.none());
            }
        }
    }

    #[test]
    fn test_bitmask() {
        test_bitmask_impl::<1, u8>();
        test_bitmask_impl::<2, u8>();
        test_bitmask_impl::<4, u8>();
        test_bitmask_impl::<8, u8>();
        test_bitmask_impl::<16, u16>();
        test_bitmask_impl::<32, u32>();
        test_bitmask_impl::<64, u64>();
    }

    // Verify that `split` / `join` on random bitmasks round-trip and that each
    // half sees the expected lanes of the original.
    fn test_bitmask_splitjoin_impl<const N: usize, const NHALF: usize>(ntrials: usize, seed: u64)
    where
        Const<N>: SupportedLaneCount,
        Const<NHALF>: SupportedLaneCount,
        BitMask<N, arch::Current>:
            SIMDMask<Arch = arch::Current> + SplitJoin<Halved = BitMask<NHALF, arch::Current>>,
        BitMask<NHALF, arch::Current>: SIMDMask<Arch = arch::Current>,
    {
        // Seeded RNG keeps the trials reproducible across runs.
        let mut rng = StdRng::seed_from_u64(seed);
        for _ in 0..ntrials {
            let base = BitMask::<N>::from_fn(ARCH, |_| StandardUniform {}.sample(&mut rng));
            let LoHi { lo, hi } = base.split();

            // Low half carries lanes [0, NHALF).
            for i in 0..NHALF {
                assert_eq!(base.get(i).unwrap(), lo.get(i).unwrap());
            }

            // High half carries lanes [NHALF, N).
            for i in 0..NHALF {
                assert_eq!(base.get(i + NHALF).unwrap(), hi.get(i).unwrap());
            }

            // Joining the halves must reproduce the original exactly.
            let joined = BitMask::<N>::join(LoHi::new(lo, hi));
            bitmasks_equal(base, joined);
        }
    }

    #[test]
    fn test_bitmask_splitjoin() {
        test_bitmask_splitjoin_impl::<2, 1>(100, 0xcbdbdca310caec88);
        test_bitmask_splitjoin_impl::<4, 2>(100, 0x9c8b9b6c70d941c5);
        test_bitmask_splitjoin_impl::<8, 4>(100, 0xc81a25918b683d39);
        test_bitmask_splitjoin_impl::<16, 8>(50, 0xad045b437c3fa0cc);
        test_bitmask_splitjoin_impl::<32, 16>(50, 0xe710ccdbbd329c77);
        test_bitmask_splitjoin_impl::<64, 32>(25, 0xd6697e3c534fc134);
    }

    // Explicit tests to ensure that upper bits are masked out during construction of
    // 2 and 4 bit masks.
    #[test]
    fn test_zeroing() {
        let b = BitMask::<2>::from_underlying(arch::current(), 0xff);
        assert_eq!(b.to_underlying(), 0x3);
        assert_eq!(b.count(), 2);

        let b = BitMask::<4>::from_underlying(arch::current(), 0xff);
        assert_eq!(b.to_underlying(), 0xf);
        assert_eq!(b.count(), 4);
    }

    // Equality assertion on the raw representation; passed as a callback to the
    // shared `test_utils::mask` routines below.
    fn bitmasks_equal<const N: usize>(x: BitMask<N, arch::Current>, y: BitMask<N, arch::Current>)
    where
        Const<N>: SupportedLaneCount,
        BitMask<N, arch::Current>: SIMDMask,
    {
        assert_eq!(x.0, y.0);
    }

    // A helper macro to run a BitMask through the SIMDMask test routines.
    macro_rules! test_simdmask {
        ($N:literal) => {
            paste::paste! {
                #[test]
                fn [<test_simd_mask_ $N>]() {
                    let arch = arch::current();
                    test_utils::mask::test_keep_first::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal
                    );
                    test_utils::mask::test_from_fn::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal
                    );
                    test_utils::mask::test_reductions::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal
                    );
                    test_utils::mask::test_first::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal
                    );
                }
            }
        };
    }

    // NOTE(review): lane count 1 is covered by `test_bitmask` above but not run
    // through these shared routines — confirm whether that omission is intentional.
    test_simdmask!(2);
    test_simdmask!(4);
    test_simdmask!(8);
    test_simdmask!(16);
    test_simdmask!(32);
    test_simdmask!(64);
}