diskann_wide/traits.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

use super::{
    arch,
    bitmask::{BitMask, FromInt},
    constant::{Const, SupportedLaneCount},
};

/// Rust currently lacks the ability to use a trait's associated constants as constraints
/// or constant parameters to other type definitions within the same trait.
///
/// The use of the `Const` type moves the const generic into the type domain, allowing us
/// to properly constrain and define other associated items.
///
/// In particular, we currently can't do nice things like:
/// ```ignore
/// trait MyTrait {
///     const MY_CONSTANT: usize;
///     type ArrayType = [f32; Self::MY_CONSTANT];
///     //               ----------------- Currently unsupported.
/// }
/// ```
///
/// See:
/// - <https://users.rust-lang.org/t/limitation-of-associated-const-in-traits/73491/2>
/// - <https://github.com/rust-lang/rust/issues/76560>
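///
/// With this indirection, the array type can be named through an associated type
/// instead. A minimal sketch (the helper `zeroed` is illustrative, not part of this
/// crate):
/// ```ignore
/// fn zeroed<T: Default + Copy, const N: usize>() -> <Const<N> as ArrayType<T>>::Type
/// where
///     Const<N>: SupportedLaneCount,
/// {
///     [T::default(); N]
/// }
/// ```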
pub trait ArrayType<T>: SupportedLaneCount {
    type Type;
}

/// Map scalar + lengths to arrays.
impl<T, const N: usize> ArrayType<T> for Const<N>
where
    Const<N>: SupportedLaneCount,
{
    type Type = [T; N];
}

/// Stable Rust does not allow expressions involving compile-time computation with
/// const generic parameters: <https://github.com/rust-lang/rust/issues/76560>
///
/// This makes it difficult to go from a const parameter defining the number of SIMD lanes
/// to an appropriately sized mask.
///
/// This helper trait provides a level of indirection to map SIMD representations to
/// the associated bitmask.
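///
/// For example, `<Const<8> as BitMaskType<A>>::Type` resolves to `BitMask<8, A>` for any
/// architecture type `A`.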
pub trait BitMaskType<A: arch::Sealed>: SupportedLaneCount {
    type Type; // Should always be a `BitMask`.
}

impl<A, const N: usize> BitMaskType<A> for Const<N>
where
    Const<N>: SupportedLaneCount,
    A: arch::Sealed,
{
    type Type = BitMask<N, A>;
}

/// Convert `Self` to the SIMD type `T`. This is mainly useful when implementing fallback
/// operations through [`crate::Emulated`] to restore the original SIMD type.
pub trait AsSIMD<T>: Copy
where
    T: SIMDVector,
{
    fn as_simd(self, arch: T::Arch) -> T;
}

/// A logical mask for SIMD operations.
///
/// The representation of this type varies between architectures and micro-architectures.
/// For example:
///
/// * On AVX2 systems, a SIMD mask for the type/length pair `(T, N)` consists of a SIMD
///   register of unsigned integers with the same size as `T` and length `N`.
///
///   The semantics of such registers are to allow operations in lanes where the top-most
///   bit is set to 1.
///
/// * On AVX-512 systems, the story is much simpler as the masks used in that instruction
///   set are simply the corresponding bit masks.
///
///   So a mask for 8-wide operations is simply an 8-bit unsigned integer.
///
/// * Emulated systems should use a bit mask for the most compact representation.
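///
/// Regardless of representation, the logical behavior is uniform. A hedged sketch,
/// assuming an 8-lane mask type `M` and a previously obtained architecture token `arch`:
/// ```ignore
/// let m = M::keep_first(arch, 3);
/// assert_eq!(m.get(0), Some(true));
/// assert_eq!(m.get(7), Some(false));
/// assert_eq!(m.count(), 3);
/// assert!(m.any() && !m.all());
/// ```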
pub trait SIMDMask: Copy + std::fmt::Debug {
    /// The architecture type this struct belongs to.
    type Arch: arch::Sealed;

    /// The type of the underlying intrinsic.
    type Underlying: Copy + std::fmt::Debug;

    /// The bitmask associated with the logical mask.
    type BitMask: SIMDMask<Arch = Self::Arch> + Into<Self> + From<Self>;

    /// The number of lanes in the mask.
    const LANES: usize;

    /// Whether or not this mask implementation is a bit mask.
    const ISBITS: bool;

    /// Return the architecture object associated with this mask.
    fn arch(self) -> Self::Arch;

    /// Retrieve the underlying type.
    /// This will always be an unsigned integer of the minimum width required to contain
    /// `LANES` bits.
    fn to_underlying(self) -> Self::Underlying;

    /// Construct the mask from the underlying type.
    fn from_underlying(arch: Self::Arch, value: Self::Underlying) -> Self;

    /// Return `true` if lane `i` is set and `false` otherwise.
    ///
    /// This method is unchecked but safe in the sense that if `i >= LANES`, `false` will
    /// always be returned. No out-of-bounds access will be made, but no error indication
    /// will be provided.
    fn get_unchecked(&self, i: usize) -> bool;

    /// Efficiently construct a new mask with the first `i` bits set and the remainder
    /// set to zero.
    ///
    /// If `i >= LANES` then all bits will be set.
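    ///
    /// For an 8-lane bit mask, for example, `keep_first(arch, 3)` corresponds to the bit
    /// pattern `0b0000_0111`.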
    fn keep_first(arch: Self::Arch, i: usize) -> Self;

    /// Return the first set index in the mask or `None` if no entries are set.
    fn first(&self) -> Option<usize> {
        self.bitmask().first()
    }

    //////////////////////////////
    // Provided Implementations //
    //////////////////////////////

    /// Return the associated `BitMask` for this mask.
    fn bitmask(self) -> Self::BitMask {
        <Self::BitMask as From<Self>>::from(self)
    }

    /// Return `true` if lane `i` is set and `false` otherwise. Returns an empty `Option`
    /// if the index `i` is out-of-bounds.
    fn get(&self, i: usize) -> Option<bool> {
        if i >= Self::LANES {
            None
        } else {
            Some(self.get_unchecked(i))
        }
    }

    /// Construct a mask based on the result of invoking `f` once for each element in the
    /// range `0..Self::LANES` in order.
    ///
    /// In the returned mask `m`, `m.get(0)` corresponds to the value of `f(0)`. Similarly,
    /// `m.get(1)` corresponds to `f(1)`, etc.
    #[inline(always)]
    fn from_fn<F>(arch: Self::Arch, f: F) -> Self
    where
        F: FnMut(usize) -> bool,
    {
        // Recurse to BitMask.
        Self::BitMask::from_fn(arch, f).into()
    }

    /// Return `true` if any lane in the mask is set. Otherwise, return `false`.
    #[inline(always)]
    fn any(self) -> bool {
        // Recurse to BitMask.
        <Self::BitMask as From<Self>>::from(self).any()
    }

    /// Return `true` if all lanes in the mask are set. Otherwise, return `false`.
    #[inline(always)]
    fn all(self) -> bool {
        // Recurse to BitMask.
        <Self::BitMask as From<Self>>::from(self).all()
    }

    /// Return `true` if no lanes in the mask are set. Otherwise, return `false`.
    #[inline(always)]
    fn none(self) -> bool {
        !self.any()
    }

    /// Return the number of lanes that evaluate to `true`.
    #[inline(always)]
    fn count(self) -> usize {
        // Recurse to BitMask.
        <Self::BitMask as From<Self>>::from(self).count()
    }
}


/// A trait representing minimal behavior for a SIMD-like vector.
///
/// A `SIMDVector` can be thought of as a homogeneous array `[T; N]` (with potentially
/// stricter alignment requirements) that generally behaves for arithmetic purposes like
/// a scalar, in the sense that
/// ```ignore
/// fn add<V>(a: V, b: V) -> V
/// where
///     V: SIMDVector,
/// {
///     a + b
/// }
/// ```
/// has the semantics of broadcasting the `+` operation across all lanes of the vectors.
pub trait SIMDVector: Copy + std::fmt::Debug {
    /// The architecture this vector belongs to.
    type Arch: arch::Sealed;

    /// The type of each element in the vector.
    type Scalar: Copy + std::fmt::Debug;

    /// The underlying representation.
    type Underlying: Copy;

    /// The number of lanes in the vector.
    const LANES: usize;

    /// The value of `LANES`, but in the type domain so we can use it to constrain other
    /// aspects of this trait.
    ///
    /// Should be the type `Const<Self::LANES>`.
    type ConstLanes: ArrayType<Self::Scalar> + BitMaskType<Self::Arch>;

    /// The expanded logical mask representation.
    /// This may or may not actually be a bitmask, but should be easily convertible to and
    /// from a bitmask.
    type Mask: SIMDMask<Arch = Self::Arch>
        + From<<Self::ConstLanes as BitMaskType<Self::Arch>>::Type>
        + Into<<Self::ConstLanes as BitMaskType<Self::Arch>>::Type>;

    /// Whether or not this is an emulated vector.
    ///
    /// Emulated vectors are backed by Rust arrays and use scalar loops to implement
    /// arithmetic operations.
    const EMULATED: bool;

    /// Return the architecture object associated with this vector.
    ///
    /// # NOTE
    ///
    /// This is safe because construction of `self` serves as the witness that we are on
    /// a compatible architecture.
    fn arch(self) -> Self::Arch;

    /// Return the default value for the type. This is always the numeric 0 for the
    /// associated scalar type.
    fn default(arch: Self::Arch) -> Self;

    /// Return the underlying type.
    fn to_underlying(self) -> Self::Underlying;

    /// Construct from the underlying type.
    fn from_underlying(arch: Self::Arch, repr: Self::Underlying) -> Self;

    /// Retrieve the contents as an array.
    fn to_array(self) -> <Self::ConstLanes as ArrayType<Self::Scalar>>::Type;

    /// Construct from the associated array.
    ///
    /// The argument `arch` provides a "proof of compatibility" as `Self::Arch` can only
    /// be safely instantiated when all the requirements for the architecture are met.
    fn from_array(arch: Self::Arch, x: <Self::ConstLanes as ArrayType<Self::Scalar>>::Type)
        -> Self;

    /// Broadcast the provided scalar across all lanes.
    ///
    /// The argument `arch` provides a "proof of compatibility" as `Self::Arch` can only
    /// be safely instantiated when all the requirements for the architecture are met.
    fn splat(arch: Self::Arch, value: Self::Scalar) -> Self;

    /// Return the number of lanes in this vector.
    fn num_lanes() -> usize {
        Self::LANES
    }

    /// Load `<Self as SIMDVector>::LANES` elements starting at the provided pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// # Safety
    ///
    /// A contiguous read of `<Self as SIMDVector>::LANES` elements must touch valid
    /// memory.
    unsafe fn load_simd(arch: Self::Arch, ptr: *const <Self as SIMDVector>::Scalar) -> Self;

    /// Load `<Self as SIMDVector>::LANES` elements starting at the provided pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// Entries in the mask that evaluate to `false` will not be accessed.
    /// This makes it safe to use this function with lanes masked out that would otherwise
    /// cross a page boundary or otherwise cause an out-of-bounds read.
    ///
    /// # Safety
    ///
    /// Offsets from `ptr` where the mask evaluates to `true` must be dereferenceable to
    /// the underlying scalar type.
    unsafe fn load_simd_masked_logical(
        arch: Self::Arch,
        ptr: *const <Self as SIMDVector>::Scalar,
        mask: <Self as SIMDVector>::Mask,
    ) -> Self;

    /// The same as `load_simd_masked_logical` but taking a `BitMask` instead.
    ///
    /// No load attempt will be made for lanes that are masked out.
    ///
    /// # Safety
    ///
    /// Offsets from `ptr` where the mask evaluates to `true` must be dereferenceable to
    /// the underlying scalar type. For implementations using the provided default, the
    /// conversion from the bitmask to the actual mask must be correct.
    #[inline(always)]
    unsafe fn load_simd_masked(
        arch: Self::Arch,
        ptr: *const <Self as SIMDVector>::Scalar,
        mask: <<Self as SIMDVector>::ConstLanes as BitMaskType<Self::Arch>>::Type,
    ) -> Self {
        // SAFETY: Bitmasks must be convertible to their corresponding logical mask.
        // When the logical mask **is** a bitmask, this is a no-op.
        unsafe { Self::load_simd_masked_logical(arch, ptr, mask.into()) }
    }

    /// The same as `load_simd_masked_logical`, but potentially specialized for situations
    /// where it is known that some number of first elements will be accessed.
    ///
    /// If `first` is greater than or equal to the number of lanes, then all lanes will be
    /// loaded.
    ///
    /// # Safety
    ///
    /// A contiguous read of `first.min(<Self as SIMDVector>::LANES)` elements must be
    /// valid.
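    ///
    /// For example, a tail loop over a slice might look like this (a sketch; `V` is any
    /// implementor of this trait and `arch` a previously obtained architecture token):
    /// ```ignore
    /// let mut i = 0;
    /// while i < data.len() {
    ///     // SAFETY: at most `data.len() - i` contiguous elements are read.
    ///     let v = unsafe { V::load_simd_first(arch, data.as_ptr().add(i), data.len() - i) };
    ///     // ... process `v` ...
    ///     i += V::LANES;
    /// }
    /// ```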
    #[inline(always)]
    unsafe fn load_simd_first(
        arch: Self::Arch,
        ptr: *const <Self as SIMDVector>::Scalar,
        first: usize,
    ) -> Self {
        // SAFETY: The implementation of `SIMDMask` must be correct.
        unsafe {
            Self::load_simd_masked_logical(
                arch,
                ptr,
                <Self as SIMDVector>::Mask::keep_first(arch, first),
            )
        }
    }

    /// Store `<Self as SIMDVector>::LANES` elements contiguously starting at the
    /// provided pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// A contiguous store of `<Self as SIMDVector>::LANES` elements must touch valid
    /// memory.
    unsafe fn store_simd(self, ptr: *mut <Self as SIMDVector>::Scalar);

    /// Store `<Self as SIMDVector>::LANES` elements starting at the provided pointer.
    ///
    /// The alignment of `ptr` must be the same as `<Self as SIMDVector>::Scalar`, but does
    /// not need to be stricter.
    ///
    /// Entries in the mask that evaluate to `false` will not be accessed.
    /// This makes it safe to use this function with lanes masked out that would otherwise
    /// cross a page boundary or otherwise cause an out-of-bounds write.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// Offsets from `ptr` where the mask evaluates to `true` must be mutably
    /// dereferenceable to the underlying scalar type.
    unsafe fn store_simd_masked_logical(
        self,
        ptr: *mut <Self as SIMDVector>::Scalar,
        mask: <Self as SIMDVector>::Mask,
    );

    /// The same as `store_simd_masked_logical` but taking a `BitMask` instead.
    ///
    /// No store attempt will be made for lanes that are masked out.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// Offsets from `ptr` where the mask evaluates to `true` must be mutably
    /// dereferenceable to the underlying scalar type.
    ///
    /// For implementations using the provided default, the conversion from the bitmask to
    /// the actual mask must be correct.
    #[inline(always)]
    unsafe fn store_simd_masked(
        self,
        ptr: *mut <Self as SIMDVector>::Scalar,
        mask: <<Self as SIMDVector>::ConstLanes as BitMaskType<Self::Arch>>::Type,
    ) {
        // SAFETY: Bitmasks must be convertible to their corresponding logical mask.
        // When the logical mask **is** a bitmask, this is a no-op.
        unsafe { self.store_simd_masked_logical(ptr, mask.into()) }
    }

    /// The same as `store_simd_masked_logical`, but potentially specialized for situations
    /// where it is known that some number of first elements will be accessed.
    ///
    /// If `first` is greater than or equal to the number of lanes, then all lanes will be
    /// written.
    ///
    /// # Safety
    ///
    /// The pointed-to memory must adhere to Rust's exclusive reference rules.
    ///
    /// A contiguous write of `first.min(<Self as SIMDVector>::LANES)` elements must be
    /// valid.
    #[inline(always)]
    unsafe fn store_simd_first(self, ptr: *mut <Self as SIMDVector>::Scalar, first: usize) {
        // SAFETY: The implementation of `SIMDMask` must be correct.
        unsafe {
            self.store_simd_masked_logical(
                ptr,
                <Self as SIMDVector>::Mask::keep_first(self.arch(), first),
            )
        }
    }

    /// Perform a numeric cast on each element, returning a new SIMD vector.
    ///
    /// See also: [`SIMDCast`].
    #[inline(always)]
    fn cast<T>(self) -> <Self as SIMDCast<T>>::Cast
    where
        Self: SIMDCast<T>,
    {
        self.simd_cast()
    }
}

/// Efficiently perform the operation
/// ```ignore
/// self * rhs + accumulator
/// ```
/// with the following semantics, dependent on the associated scalar type:
///
/// * floating point: Perform a fused multiply-add, implementing the operation with only a
///   single rounding step.
///
/// * integer: Perform the multiplication followed by the accumulation. Both binary
///   operations will be performed using wrap-around arithmetic.
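///
/// Lane-wise, the semantics are roughly (an illustrative model, not the implementation):
/// ```ignore
/// out[i] = self[i].mul_add(rhs[i], accumulator[i]);                   // float
/// out[i] = self[i].wrapping_mul(rhs[i]).wrapping_add(accumulator[i]); // integer
/// ```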
pub trait SIMDMulAdd {
    fn mul_add_simd(self, rhs: Self, accumulator: Self) -> Self;
}

/// Efficiently retrieve the pairwise minimum or maximum of the two arguments.
///
/// Each function comes in two flavors:
///
/// * Standard (suffixed): Compute the minimum or maximum in a way that is equivalent to
///   Rust's built-in minimum or maximum functions.
///
///   When the scalar type is integral, the behavior is unambiguous.
///
///   When the scalar type is floating point and one value of a pair is NaN, the other
///   value is returned. When the result is zero, either a positive or a negative zero can
///   be returned.
///
/// * Fast (unsuffixed): Compute the minimum or maximum using the fastest possible method
///   on the given architecture, with non-standard NaN handling.
///
///   When the scalar type is integral, the behavior is the same as the standard
///   implementations.
///
///   When the scalar type is floating point, the implementation is allowed to differ
///   with respect to NaN handling. That is, when one of the arguments is NaN, the
///   implementation is allowed to return **either** the other argument (like the standard
///   implementation) or NaN. Like the standard implementation, if the result is zero, then
///   a zero of either sign can be returned.
///
///   This flavor should be preferred when precise NaN handling is not needed as it can be
///   more efficient.
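///
/// For example, with lanes `x = [1.0, f32::NAN]` and `y = [2.0, 3.0]` (illustrative
/// values), `x.min_simd_standard(y)` yields `[1.0, 3.0]`, while `x.min_simd(y)` may
/// yield either `[1.0, 3.0]` or `[1.0, NaN]` depending on the architecture.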
pub trait SIMDMinMax: Sized {
    /// Return the pairwise minimum of `self` and `rhs`, subject to looser NaN handling.
    fn min_simd(self, rhs: Self) -> Self;

    /// Return the pairwise minimum of `self` and `rhs` as if by applying the standard
    /// library's `min` method for the scalar type.
    #[inline(always)]
    fn min_simd_standard(self, rhs: Self) -> Self {
        self.min_simd(rhs)
    }

    /// Return the pairwise maximum of `self` and `rhs`, subject to looser NaN handling.
    fn max_simd(self, rhs: Self) -> Self;

    /// Return the pairwise maximum of `self` and `rhs` as if by applying the standard
    /// library's `max` method for the scalar type.
    #[inline(always)]
    fn max_simd_standard(self, rhs: Self) -> Self {
        self.max_simd(rhs)
    }
}

/// Take the absolute value of each lane.
///
/// # Notes
///
/// For signed integer types `T`, this works as expected for all values except `T::MIN`,
/// in which case `T::MIN` is returned. This keeps the behavior in line with hardware
/// intrinsics.
///
/// A correct answer can be retrieved by casting the result to the equivalent unsigned
/// integer type.
pub trait SIMDAbs {
    fn abs_simd(self) -> Self;
}

/// A SIMD equivalent of `std::cmp::PartialEq`.
///
/// Instead of a boolean, return `Self::Mask` containing the result of the element-wise
/// comparison of the two vectors.
pub trait SIMDPartialEq: SIMDVector {
    /// SIMD equivalent of `std::cmp::PartialEq::eq`, applying the latter to each
    /// lane-wise pair of elements in `self` and `other`.
    fn eq_simd(self, other: Self) -> Self::Mask;

    /// SIMD equivalent of `std::cmp::PartialEq::ne`, applying the latter to each
    /// lane-wise pair of elements in `self` and `other`.
    fn ne_simd(self, other: Self) -> Self::Mask;
}

/// A SIMD equivalent of `std::cmp::PartialOrd`.
///
/// Instead of a boolean, return `Self::Mask` containing the result of the element-wise
/// comparisons of the two vectors.
pub trait SIMDPartialOrd: SIMDVector {
    /// SIMD equivalent of `std::cmp::PartialOrd::lt`.
    fn lt_simd(self, other: Self) -> Self::Mask;

    /// SIMD equivalent of `std::cmp::PartialOrd::le`.
    fn le_simd(self, other: Self) -> Self::Mask;

    //////////////////////
    // Provided Methods //
    //////////////////////

    /// SIMD equivalent of `std::cmp::PartialOrd::gt`.
    ///
    /// Types are free to override the provided method if a more efficient implementation
    /// is possible.
    #[inline(always)]
    fn gt_simd(self, other: Self) -> Self::Mask {
        other.lt_simd(self)
    }

    /// SIMD equivalent of `std::cmp::PartialOrd::ge`.
    ///
    /// Types are free to override the provided method if a more efficient implementation
    /// is possible.
    #[inline(always)]
    fn ge_simd(self, other: Self) -> Self::Mask {
        other.le_simd(self)
    }
}

/// Perform a pairwise reducing sum of all lanes in the vector and return the result as a
/// scalar.
///
/// For example, the summing pattern for a vector of 8 elements is as follows:
/// ```text
/// let v0 = [x0, x1, x2, x3, x4, x5, x6, x7];
/// let v1 = [v0[0] + v0[4], v0[1] + v0[5], v0[2] + v0[6], v0[3] + v0[7]];
/// let v2 = [v1[0] + v1[2], v1[1] + v1[3]];
/// v2[0] + v2[1]
/// ```
pub trait SIMDSumTree: SIMDVector {
    fn sum_tree(self) -> <Self as SIMDVector>::Scalar;
}


/// A vectorized "if else".
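///
/// Lane-wise: `result[i] = if mask[i] { x[i] } else { y[i] }`.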
pub trait SIMDSelect<V: SIMDVector>: SIMDMask {
    fn select(self, x: V, y: V) -> V;
}

/// Optimized dot-product style accumulation.
///
/// This tries to match against intrinsics like:
///
/// * `_mm256_madd_epi16`
/// * `_mm256_dpbusd_epi32`
///
/// The gist is to perform element-wise multiplication between `left` and `right`,
/// promoting the results to the element type of `Self`, then summing adjacent entries
/// and adding the sums into `self`.
///
/// # Precise Enumeration of Semantics for Implementations
///
/// The semantics depend on the source and destination type, but are intended to be the same
/// for each type combination across architectures.
///
/// ## `SIMDDotProduct<i16x16> for i32x8`
///
/// 1. Perform multiplication as `i16x16 x i16x16` as if converting each lane to `i32`,
///    resulting in effectively `i32x16`. No overflow can happen.
/// 2. Add together adjacent pairs in the resulting `i32x16` to yield `i32x8`. This step
///    cannot overflow except in the single edge case where both lanes of a pair are
///    `i16::MIN * i16::MIN`, in which case the sum wraps, matching hardware behavior.
/// 3. Add the resulting `i32x8` into `Self`, returning the result.
///
/// ## `SIMDDotProduct<u8x32, i8x32> for i32x8`
///
/// 1. Perform multiplication as `u8x32 x i8x32` as if converting each lane to `i32`,
///    resulting in effectively `i32x32`. No overflow can happen.
/// 2. Sum together consecutive groups of 4 in the resulting `i32x32` to yield `i32x8`.
/// 3. Add the resulting `i32x8` into `Self`.
///
/// The same applies when the order of `u8x32` and `i8x32` is swapped and for types that
/// are twice as wide.
///
/// The main goal of this function is to hit VNNI instructions like `_mm512_dpbusd_epi32`
/// that can do the whole operation in a single go on the `V4` architecture. Use of this
/// instruction is not recommended on non-`V4` architectures.
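///
/// A scalar model of the `SIMDDotProduct<i16x16> for i32x8` case (illustrative only;
/// the lane types follow the enumeration above):
/// ```ignore
/// for g in 0..8 {
///     let p0 = (left[2 * g] as i32) * (right[2 * g] as i32);
///     let p1 = (left[2 * g + 1] as i32) * (right[2 * g + 1] as i32);
///     out[g] = self[g].wrapping_add(p0.wrapping_add(p1));
/// }
/// ```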
pub trait SIMDDotProduct<L: SIMDVector, R: SIMDVector = L> {
    /// Element-wise multiply each component of `left` and `right`, promoting the
    /// intermediate results to a higher precision.
    ///
    /// Then, horizontally add together groups of the accumulated values and add the
    /// resulting sums to `self`.
    ///
    /// The size of each group depends on the relative number of lanes in `Self` and `L`.
    ///
    /// However, it is required that `Self::num_lanes()` evenly divides `L::num_lanes()`
    /// so that the size of each group is uniform.
    fn dot_simd(self, left: L, right: R) -> Self;
}

/// Perform a bit-cast from one SIMD type to another.
pub trait SIMDReinterpret<To: SIMDVector>: SIMDVector {
    fn reinterpret_simd(self) -> To;
}

/// Perform a numeric cast on the scalar type.
///
/// Unlike `From`, this conversion is allowed to be lossy, with similar semantics to
/// numeric casts in scalar Rust.
///
/// This is meant to model Rust's numeric conversion with the "as" operator.
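///
/// Lane-wise, this models `out[i] = self[i] as T`, e.g. converting an 8-lane `f32`
/// vector to an 8-lane `i32` vector (lane types illustrative).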
pub trait SIMDCast<T>: SIMDVector {
    /// The [`SIMDVector`] type of the result.
    type Cast: SIMDVector<Scalar = T, ConstLanes = Self::ConstLanes>;

    /// Perform the cast.
    fn simd_cast(self) -> Self::Cast;
}

/// A roll-up of traits required for SIMD floating point types.
pub trait SIMDFloat:
    SIMDVector
    + std::ops::Add<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::Sub<Output = Self>
    + SIMDMulAdd
    + SIMDMinMax
    + SIMDPartialEq
    + SIMDPartialOrd
{
}

impl<T> SIMDFloat for T where
    T: SIMDVector
        + std::ops::Add<Output = Self>
        + std::ops::Mul<Output = Self>
        + std::ops::Sub<Output = Self>
        + SIMDMulAdd
        + SIMDMinMax
        + SIMDPartialEq
        + SIMDPartialOrd
{
}

/// A roll-up of traits required for SIMD integer types.
pub trait SIMDUnsigned:
    SIMDVector
    + std::ops::Add<Output = Self>
    + std::ops::Mul<Output = Self>
    + std::ops::Sub<Output = Self>
    + std::ops::BitAnd<Output = Self>
    + std::ops::BitOr<Output = Self>
    + std::ops::BitXor<Output = Self>
    + std::ops::Shr<Output = Self>
    + std::ops::Shl<Output = Self>
    + std::ops::Shr<Self::Scalar, Output = Self>
    + std::ops::Shl<Self::Scalar, Output = Self>
    + SIMDMulAdd
    + SIMDPartialEq
    + SIMDPartialOrd
{
}

impl<T> SIMDUnsigned for T where
    T: SIMDVector
        + std::ops::Add<Output = Self>
        + std::ops::Mul<Output = Self>
        + std::ops::Sub<Output = Self>
        + std::ops::BitAnd<Output = Self>
        + std::ops::BitOr<Output = Self>
        + std::ops::BitXor<Output = Self>
        + std::ops::Shr<Output = Self>
        + std::ops::Shl<Output = Self>
        + std::ops::Shr<Self::Scalar, Output = Self>
        + std::ops::Shl<Self::Scalar, Output = Self>
        + SIMDMulAdd
        + SIMDPartialEq
        + SIMDPartialOrd
{
}

/// A roll-up of traits required for signed SIMD integer types.
pub trait SIMDSigned: SIMDUnsigned + SIMDAbs {}
impl<T> SIMDSigned for T where T: SIMDUnsigned + SIMDAbs {}

/// Element-wise zip and unzip of two half-width SIMD vectors.
///
/// `zip` interleaves elements from two halves into one full-width vector:
/// `zip([a0, a1, …], [b0, b1, …]) = [a0, b0, a1, b1, …]`
///
/// `unzip` is the inverse, separating even- and odd-indexed elements:
/// `unzip([a0, b0, a1, b1, …]) = ([a0, a1, …], [b0, b1, …])`
///
/// Two additional "flat" methods operate on a single full-width register:
///
/// `zip_flat` treats the low half of `self` as one input and the high half as the other,
/// interleaving them in-place:
/// `zip_flat([a0, a1, …, b0, b1, …]) = [a0, b0, a1, b1, …]`
///
/// `unzip_flat` is the inverse, collecting even-indexed elements into the low half and
/// odd-indexed elements into the high half:
/// `unzip_flat([a0, b0, a1, b1, …]) = [a0, a1, …, b0, b1, …]`
///
/// `zip` and `unzip` are required. `zip_flat` and `unzip_flat` have default
/// implementations that delegate through [`SplitJoin::split`] / [`SplitJoin::join`].
/// Backends should override whichever pair is cheapest on their ISA.
pub trait ZipUnzip: crate::SplitJoin + Sized {
    /// Interleave elements from `halves.lo` and `halves.hi` into `Self`.
    fn zip(halves: crate::LoHi<Self::Halved>) -> Self;

    /// Separate even-indexed elements into `lo` and odd-indexed elements into `hi`.
    fn unzip(self) -> crate::LoHi<Self::Halved>;

    /// Interleave in-place: the low half of `self` supplies the even-indexed
    /// positions and the high half supplies the odd-indexed positions.
    ///
    /// Equivalent to `Self::zip(self.split())`, but may be implemented with a
    /// single cross-lane permute on architectures that support it.
    fn zip_flat(self) -> Self {
        Self::zip(self.split())
    }

    /// Deinterleave in-place: even-indexed elements are collected into the low
    /// half of the result and odd-indexed elements into the high half.
    ///
    /// Equivalent to `Self::join(self.unzip())`, but may be implemented with a
    /// single cross-lane permute on architectures that support it.
    fn unzip_flat(self) -> Self {
        Self::unzip(self).join()
    }
}

// Since it is so difficult to work directly with generic integers, resort to using a macro
// to stamp out implementations of `SIMDMask` for `BitMask`.
//
// The argument `$submask` is a bit pattern for the underlying type, used to mask out the
// upper bits of the representation for 2- and 4-bit masks.
macro_rules! impl_simd_mask_for_bitmask {
    ($N:literal, $repr:ty, $submask:expr) => {
        impl<A: arch::Sealed> SIMDMask for BitMask<$N, A> {
            type Arch = A;
            type Underlying = $repr;
            type BitMask = Self;
            const ISBITS: bool = true;
            const LANES: usize = $N;

            #[inline(always)]
            fn arch(self) -> A {
                self.get_arch()
            }

            #[inline(always)]
            fn to_underlying(self) -> Self::Underlying {
                self.0
            }

            #[inline(always)]
            fn from_underlying(arch: A, value: Self::Underlying) -> Self {
                Self::from_int(arch, value)
            }

            #[inline(always)]
            fn keep_first(arch: A, i: usize) -> Self {
                // Ensure that providing a value that is too big still yields sensible
                // results.
                let i = i.min(Self::LANES);

                // Handle 64-bit integers properly.
                // It is expected that the compiler will be able to optimize out this branch
                // for non-64-bit types.
                if Self::LANES == 64 && i == 64 {
                    return Self::from_underlying(arch, Self::Underlying::MAX);
                }

                let one: u64 = 1;
                // "as" conversion in Rust performs truncation on integers.
                Self::from_underlying(arch, ((one << i) - one) as Self::Underlying)
            }

            #[inline(always)]
            fn get_unchecked(&self, i: usize) -> bool {
                if i >= Self::LANES {
                    false
                } else {
                    (self.0 >> i) % 2 == 1
                }
            }

            #[inline(always)]
            fn first(&self) -> Option<usize> {
                let count = self.0.trailing_zeros() as usize;
                if count >= Self::LANES {
                    None
                } else {
                    Some(count)
                }
            }

            // End-of-recursion functions.
            fn from_fn<F>(arch: A, mut f: F) -> Self
            where
                F: FnMut(usize) -> bool,
            {
                let mut x: $repr = 0;
                for i in 0..Self::LANES {
                    if f(i) {
                        x |= 1 << i;
                    }
                }
                Self::from_underlying(arch, x)
            }

            #[inline(always)]
            fn any(self) -> bool {
                self.0 != 0
            }

            #[inline(always)]
            fn all(self) -> bool {
                let v: u64 = self.0.into();

                // We again need to handle 64-wide masks differently.
                if $N == 64 {
                    v == u64::MAX
                } else {
                    v == (1 << $N) - 1
                }
            }

            #[inline(always)]
            fn count(self) -> usize {
                // We keep the invariant that all constructors of `BitMask` must zero out
                // upper bits (for 2- and 4-width BitMasks).
                self.0.count_ones() as usize
            }
        }

        // Transformation from bitmask to integer.
        impl From<BitMask<$N>> for $repr {
            fn from(value: BitMask<$N>) -> Self {
                value.to_underlying()
            }
        }
    };
}

// Stamp out a bunch of implementations.
impl_simd_mask_for_bitmask!(1, u8, 0x1);
impl_simd_mask_for_bitmask!(2, u8, 0x3);
impl_simd_mask_for_bitmask!(4, u8, 0xf);
impl_simd_mask_for_bitmask!(8, u8, u8::MAX);
impl_simd_mask_for_bitmask!(16, u16, u16::MAX);
impl_simd_mask_for_bitmask!(32, u32, u32::MAX);
impl_simd_mask_for_bitmask!(64, u64, u64::MAX);

#[cfg(test)]
mod test_traits {
    use rand::{
        SeedableRng,
        distr::{Distribution, StandardUniform},
        rngs::StdRng,
    };

    use super::*;
    use crate::{
        ARCH, arch,
        splitjoin::{LoHi, SplitJoin},
        test_utils,
    };

    // Allow unsigned 128-bit integers to be converted to narrower types.
    trait FromU128 {
        fn from_(value: u128) -> Self;
    }

    impl FromU128 for u8 {
        fn from_(value: u128) -> Self {
            value as u8
        }
    }
    impl FromU128 for u16 {
        fn from_(value: u128) -> Self {
            value as u16
        }
    }
    impl FromU128 for u32 {
        fn from_(value: u128) -> Self {
            value as u32
        }
    }
    impl FromU128 for u64 {
        fn from_(value: u128) -> Self {
            value as u64
        }
    }

    /// Test that bitmasks faithfully implement the trait `SIMDMask`.
    ///
    /// Since conversions between `u128` and an arbitrary generic parameter `T` are not
    /// allowed, we use the `FromU128` helper trait to perform the conversion for us with
    /// all the known type information.
    fn test_bitmask_impl<const N: usize, T>()
    where
        Const<N>: SupportedLaneCount, // This value of `N` has a bitmask representation.
        T: std::fmt::Debug + std::cmp::Eq + FromU128 + From<BitMask<N, arch::Current>>,
        BitMask<N, arch::Current>: SIMDMask<Arch = arch::Current, Underlying = T>,
    {
        const MAXLEN: usize = 64;
        assert_eq!(N, BitMask::<N, arch::Current>::LANES);

        // The bit-mask corresponding to all lanes.
        let one = 1_u128;
        let all: u128 = (one << N) - one;

        for i in 0..=MAXLEN {
            let mask = BitMask::<N, arch::Current>::keep_first(arch::current(), i);
            let expected: u128 = ((one << i) - one) & all;

            // Cannot use "as" since `T` is not known to be a primitive type ...
            assert_eq!(mask.to_underlying(), T::from_(expected));
            assert_eq!(T::from_(expected), mask.into());
            for j in 0..=MAXLEN {
                let b = mask.get_unchecked(j);
                let o = mask.get(j);

                let expected: bool = j < i;
                if j < N {
                    assert_eq!(b, expected);
                    assert_eq!(o.unwrap(), expected);
                } else {
                    assert!(!b);
                    assert!(o.is_none());
                }
            }

            // Check reductions.
            if i == 0 {
                assert!(!mask.any());
                assert!(!mask.all());
                assert!(mask.none());
            } else if i >= N {
                assert!(mask.any());
                assert!(mask.all());
                assert!(!mask.none());
            } else {
                assert!(mask.any());
                assert!(!mask.all());
                assert!(!mask.none());
            }
        }
    }

    #[test]
    fn test_bitmask() {
        test_bitmask_impl::<1, u8>();
        test_bitmask_impl::<2, u8>();
        test_bitmask_impl::<4, u8>();
        test_bitmask_impl::<8, u8>();
        test_bitmask_impl::<16, u16>();
        test_bitmask_impl::<32, u32>();
        test_bitmask_impl::<64, u64>();
    }

    fn test_bitmask_splitjoin_impl<const N: usize, const NHALF: usize>(ntrials: usize, seed: u64)
    where
        Const<N>: SupportedLaneCount,
        Const<NHALF>: SupportedLaneCount,
        BitMask<N, arch::Current>:
            SIMDMask<Arch = arch::Current> + SplitJoin<Halved = BitMask<NHALF, arch::Current>>,
        BitMask<NHALF, arch::Current>: SIMDMask<Arch = arch::Current>,
    {
        let mut rng = StdRng::seed_from_u64(seed);
        for _ in 0..ntrials {
            let base = BitMask::<N>::from_fn(ARCH, |_| StandardUniform {}.sample(&mut rng));
            let LoHi { lo, hi } = base.split();

            for i in 0..NHALF {
                assert_eq!(base.get(i).unwrap(), lo.get(i).unwrap());
            }

            for i in 0..NHALF {
                assert_eq!(base.get(i + NHALF).unwrap(), hi.get(i).unwrap());
            }

            let joined = BitMask::<N>::join(LoHi::new(lo, hi));
            bitmasks_equal(base, joined);
        }
    }

    #[test]
    fn test_bitmask_splitjoin() {
        test_bitmask_splitjoin_impl::<2, 1>(100, 0xcbdbdca310caec88);
        test_bitmask_splitjoin_impl::<4, 2>(100, 0x9c8b9b6c70d941c5);
        test_bitmask_splitjoin_impl::<8, 4>(100, 0xc81a25918b683d39);
        test_bitmask_splitjoin_impl::<16, 8>(50, 0xad045b437c3fa0cc);
        test_bitmask_splitjoin_impl::<32, 16>(50, 0xe710ccdbbd329c77);
        test_bitmask_splitjoin_impl::<64, 32>(25, 0xd6697e3c534fc134);
    }

    // Explicit tests to ensure that upper bits are masked out during construction of
    // 2- and 4-bit masks.
    #[test]
    fn test_zeroing() {
        let b = BitMask::<2>::from_underlying(arch::current(), 0xff);
        assert_eq!(b.to_underlying(), 0x3);
        assert_eq!(b.count(), 2);

        let b = BitMask::<4>::from_underlying(arch::current(), 0xff);
        assert_eq!(b.to_underlying(), 0xf);
        assert_eq!(b.count(), 4);
    }

    fn bitmasks_equal<const N: usize>(x: BitMask<N, arch::Current>, y: BitMask<N, arch::Current>)
    where
        Const<N>: SupportedLaneCount,
        BitMask<N, arch::Current>: SIMDMask,
    {
        assert_eq!(x.0, y.0);
    }

    // A helper macro to run a BitMask through the SIMDMask test routines.
    macro_rules! test_simdmask {
        ($N:literal) => {
            paste::paste! {
                #[test]
                fn [<test_simd_mask_ $N>]() {
                    let arch = arch::current();
                    test_utils::mask::test_keep_first::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal,
                    );
                    test_utils::mask::test_from_fn::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal,
                    );
                    test_utils::mask::test_reductions::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal,
                    );
                    test_utils::mask::test_first::<BitMask<$N, arch::Current>, $N, _, _>(
                        arch,
                        bitmasks_equal,
                    );
                }
            }
        };
    }

    test_simdmask!(2);
    test_simdmask!(4);
    test_simdmask!(8);
    test_simdmask!(16);
    test_simdmask!(32);
    test_simdmask!(64);
}