// diskann_wide/emulated.rs
/*
 * Copyright (c) Microsoft Corporation.
 * Licensed under the MIT license.
 */

6use half::f16;
7
8use super::{
9    SplitJoin, SupportedLaneCount, ZipUnzip,
10    arch::{self, emulated::Scalar},
11    bitmask::BitMask,
12    constant::Const,
13    reference::{ReferenceAbs, ReferenceCast, ReferenceScalarOps, ReferenceShifts, TreeReduce},
14    traits::{
15        ArrayType, SIMDAbs, SIMDCast, SIMDDotProduct, SIMDMask, SIMDMinMax, SIMDMulAdd,
16        SIMDPartialEq, SIMDPartialOrd, SIMDReinterpret, SIMDSelect, SIMDSumTree, SIMDVector,
17    },
18};
19
/// An emulated SIMD vector.
///
/// The emulated implementation behaves just like an intrinsic, but the APIs are implemented
/// using loops over arrays rather than dispatching to platform specific instructions.
///
/// The idea behind this type is that it can be used on architecture where explicit backend
/// support has not been added, or when an architecture does not support a given type/length
/// pair well.
///
/// Furthermore, it can be used when developing new back-ends to provide fallback
/// implementations. This allows new back-ends to be developed one piece at a time instead
/// of all at once.
///
/// NOTE: The alignment requirements of an emulated vector *will* be different than the
/// alignment requirements of an actual intrinsic.
///
/// Higher level code *must not* rely on alignments being compatible across architectures!
///
/// Field 0 holds the per-lane values; field 1 holds the architecture token that is
/// threaded through every operation.
#[derive(Debug, Clone, Copy)]
pub struct Emulated<T, const N: usize, A = Scalar>(pub(crate) [T; N], A);
39
40impl<T, const N: usize, A> Emulated<T, N, A> {
41    pub fn from_arch_fn<F>(arch: A, f: F) -> Self
42    where
43        F: FnMut(usize) -> T,
44    {
45        Self(core::array::from_fn(f), arch)
46    }
47}
48
impl<T, const N: usize, A> SIMDVector for Emulated<T, N, A>
where
    T: Copy + std::fmt::Debug + Default,
    Const<N>: ArrayType<T, Type = [T; N]>,
    BitMask<N, A>: SIMDMask<Arch = A>,
    A: arch::Sealed,
{
    type Arch = A;
    type Scalar = T;
    type Underlying = [T; N];
    type ConstLanes = Const<N>;
    // Lane count equals the const generic parameter.
    const LANES: usize = N;
    type Mask = BitMask<N, A>;

    /// The underlying behavior is emulated using loops and is not accelerated by back-end
    /// intrinsics.
    const EMULATED: bool = true;

    /// Return the architecture token carried alongside the lanes.
    fn arch(self) -> A {
        self.1
    }

    /// All lanes set to `T::default()`.
    fn default(arch: A) -> Self {
        Self([T::default(); N], arch)
    }

    /// Return the underlying array.
    fn to_underlying(self) -> Self::Underlying {
        self.0
    }

    /// Construct from the underlying array.
    fn from_underlying(arch: A, repr: [T; N]) -> Self {
        Self(repr, arch)
    }

    /// Return the underlying array.
    fn to_array(self) -> [T; N] {
        self.0
    }

    /// Construct from the underlying array.
    fn from_array(arch: A, x: [T; N]) -> Self {
        Self(x, arch)
    }

    /// Broadcast the provided scalar across all lanes.
    fn splat(arch: A, value: Self::Scalar) -> Self {
        Self([value; N], arch)
    }

    /// Load `N` contiguous values starting at `ptr`. The read is unaligned, so
    /// `ptr` carries no alignment requirement.
    ///
    /// # Safety
    /// `ptr` must be valid for reading `N` contiguous `T` values.
    #[inline(always)]
    unsafe fn load_simd(arch: A, ptr: *const T) -> Self {
        // SAFETY: The caller asserts that `ptr` is contiguously readable for `N` values.
        Self(
            unsafe { std::ptr::read_unaligned(ptr.cast::<[T; N]>()) },
            arch,
        )
    }

    /// Only load values when the corresponding mask lane is set; unset lanes
    /// are filled with `T::default()`.
    ///
    /// # Safety
    /// For every set lane `i` in `mask`, `ptr.add(i)` must be valid for reads.
    unsafe fn load_simd_masked_logical(arch: A, ptr: *const T, mask: Self::Mask) -> Self {
        Self::from_arch_fn(arch, |i| {
            if mask.get_unchecked(i) {
                // SAFETY: The caller ensures it's safe to access this offset from `ptr`
                // because the lane in `mask` is set.
                unsafe { std::ptr::read_unaligned(ptr.add(i)) }
            } else {
                T::default()
            }
        })
    }

    /// Only load the first `first` items. Remaining lanes are set to
    /// `T::default()`.
    ///
    /// # Safety
    /// `ptr` must be valid for reading the first `first` `T` values.
    #[inline(always)]
    unsafe fn load_simd_first(arch: A, ptr: *const T, first: usize) -> Self {
        Self::from_arch_fn(arch, |i| {
            if i < first {
                // SAFETY: The caller ensures it's safe to access the first `first` values
                // beginning at `ptr`.
                unsafe { std::ptr::read_unaligned(ptr.add(i)) }
            } else {
                T::default()
            }
        })
    }

    /// Store all `N` lanes contiguously at `ptr` (unaligned write).
    ///
    /// # Safety
    /// `ptr` must be valid for writing `N` contiguous `T` values.
    #[inline(always)]
    unsafe fn store_simd(self, ptr: *mut T) {
        // SAFETY: The caller asserts that it is safe to write `N` contiguous values to `ptr`.
        unsafe { ptr.cast::<[T; N]>().write_unaligned(self.0) }
    }

    /// Only store values when the corresponding mask lane is set. Destination
    /// elements for unset lanes are left untouched.
    unsafe fn store_simd_masked_logical(self, ptr: *mut T, mask: Self::Mask) {
        for (i, v) in self.0.iter().enumerate() {
            if mask.get_unchecked(i) {
                // SAFETY: The caller asserts it is safe to write to offsets with the
                // corresponding bit mask set.
                unsafe { ptr.add(i).write_unaligned(*v) };
            }
        }
    }

    /// Only store the first `first` items. Destination elements past `first`
    /// are not written.
    #[inline(always)]
    unsafe fn store_simd_first(self, ptr: *mut T, first: usize) {
        for (i, v) in self.0.iter().enumerate().take(first) {
            // SAFETY: The caller asserts it is safe to write to the first `first` offsets
            // beginning at `ptr`.
            unsafe { ptr.add(i).write_unaligned(*v) };
        }
    }
}
166
167/// Binary Ops
168impl<T, const N: usize, A> std::ops::Add for Emulated<T, N, A>
169where
170    T: ReferenceScalarOps + Copy + std::fmt::Debug + std::default::Default,
171    Const<N>: ArrayType<T>,
172{
173    type Output = Self;
174    fn add(self, rhs: Self) -> Self {
175        Self::from_arch_fn(self.1, |i| self.0[i].expected_add_(rhs.0[i]))
176    }
177}
178
179impl<T, const N: usize, A> std::ops::Sub for Emulated<T, N, A>
180where
181    T: ReferenceScalarOps,
182{
183    type Output = Self;
184
185    #[inline(always)]
186    fn sub(self, rhs: Self) -> Self {
187        Self::from_arch_fn(self.1, |i| self.0[i].expected_sub_(rhs.0[i]))
188    }
189}
190
191impl<T, const N: usize, A> std::ops::Mul for Emulated<T, N, A>
192where
193    T: ReferenceScalarOps,
194{
195    type Output = Self;
196    fn mul(self, rhs: Self) -> Self {
197        Self::from_arch_fn(self.1, |i| self.0[i].expected_mul_(rhs.0[i]))
198    }
199}
200
201/// MulAdd
202impl<T, const N: usize, A> SIMDMulAdd for Emulated<T, N, A>
203where
204    T: ReferenceScalarOps,
205{
206    #[inline(always)]
207    fn mul_add_simd(self, rhs: Self, accumulator: Self) -> Self {
208        Self::from_arch_fn(self.1, |i| {
209            self.0[i].expected_fma_(rhs.0[i], accumulator.0[i])
210        })
211    }
212}
213
214/// MinMax
215impl<T, const N: usize, A> SIMDMinMax for Emulated<T, N, A>
216where
217    T: ReferenceScalarOps,
218{
219    #[inline(always)]
220    fn min_simd(self, rhs: Self) -> Self {
221        Self::from_arch_fn(self.1, |i| self.0[i].expected_min_(rhs.0[i]))
222    }
223    #[inline(always)]
224    fn max_simd(self, rhs: Self) -> Self {
225        Self::from_arch_fn(self.1, |i| self.0[i].expected_max_(rhs.0[i]))
226    }
227}
228
229/// Abs
230impl<T, const N: usize, A> SIMDAbs for Emulated<T, N, A>
231where
232    T: ReferenceAbs,
233{
234    #[inline(always)]
235    fn abs_simd(self) -> Self {
236        Self::from_arch_fn(self.1, |i| self.0[i].expected_abs_())
237    }
238}
239
240/// SIMDPartialEq
241impl<T, const N: usize, A> SIMDPartialEq for Emulated<T, N, A>
242where
243    T: PartialEq,
244    Self: SIMDVector,
245{
246    #[inline(always)]
247    fn eq_simd(self, other: Self) -> Self::Mask {
248        Self::Mask::from_fn(self.arch(), |i| self.0[i] == other.0[i])
249    }
250
251    #[inline(always)]
252    fn ne_simd(self, other: Self) -> Self::Mask {
253        Self::Mask::from_fn(self.arch(), |i| self.0[i] != other.0[i])
254    }
255}
256
257/// SIMDPartialOrd
258impl<T, const N: usize, A> SIMDPartialOrd for Emulated<T, N, A>
259where
260    T: PartialOrd,
261    Self: SIMDVector,
262{
263    #[inline(always)]
264    fn lt_simd(self, other: Self) -> Self::Mask {
265        Self::Mask::from_fn(self.arch(), |i| self.0[i] < other.0[i])
266    }
267
268    #[inline(always)]
269    fn le_simd(self, other: Self) -> Self::Mask {
270        Self::Mask::from_fn(self.arch(), |i| self.0[i] <= other.0[i])
271    }
272
273    #[inline(always)]
274    fn gt_simd(self, other: Self) -> Self::Mask {
275        Self::Mask::from_fn(self.arch(), |i| self.0[i] > other.0[i])
276    }
277
278    #[inline(always)]
279    fn ge_simd(self, other: Self) -> Self::Mask {
280        Self::Mask::from_fn(self.arch(), |i| self.0[i] >= other.0[i])
281    }
282}
283
284// Bit Ops
285impl<T, const N: usize, A> std::ops::BitAnd for Emulated<T, N, A>
286where
287    T: std::ops::BitAnd<Output = T> + Copy,
288{
289    type Output = Self;
290    #[inline(always)]
291    fn bitand(self, other: Self) -> Self::Output {
292        Self::from_arch_fn(self.1, |i| self.0[i] & other.0[i])
293    }
294}
295
296impl<T, const N: usize, A> std::ops::BitOr for Emulated<T, N, A>
297where
298    T: std::ops::BitOr<Output = T> + Copy,
299{
300    type Output = Self;
301    #[inline(always)]
302    fn bitor(self, other: Self) -> Self::Output {
303        Self::from_arch_fn(self.1, |i| self.0[i] | other.0[i])
304    }
305}
306
307impl<T, const N: usize, A> std::ops::BitXor for Emulated<T, N, A>
308where
309    T: std::ops::BitXor<Output = T> + Copy,
310{
311    type Output = Self;
312    #[inline(always)]
313    fn bitxor(self, other: Self) -> Self::Output {
314        Self::from_arch_fn(self.1, |i| self.0[i] ^ other.0[i])
315    }
316}
317
318impl<T, const N: usize, A> std::ops::Not for Emulated<T, N, A>
319where
320    T: std::ops::Not<Output = T> + Copy,
321{
322    type Output = Self;
323    #[inline(always)]
324    fn not(self) -> Self::Output {
325        Self::from_arch_fn(self.1, |i| !self.0[i])
326    }
327}
328
329impl<T, const N: usize, A> std::ops::Shl for Emulated<T, N, A>
330where
331    T: ReferenceShifts,
332{
333    type Output = Self;
334    #[inline(always)]
335    fn shl(self, rhs: Self) -> Self::Output {
336        Self::from_arch_fn(self.1, |i| self.0[i].expected_shl_(rhs.0[i]))
337    }
338}
339
340impl<T, const N: usize, A> std::ops::Shl<T> for Emulated<T, N, A>
341where
342    T: ReferenceShifts,
343{
344    type Output = Self;
345    #[inline(always)]
346    fn shl(self, rhs: T) -> Self::Output {
347        Self::from_arch_fn(self.1, |i| self.0[i].expected_shl_(rhs))
348    }
349}
350
351impl<T, const N: usize, A> std::ops::Shr for Emulated<T, N, A>
352where
353    T: ReferenceShifts,
354{
355    type Output = Self;
356    #[inline(always)]
357    fn shr(self, rhs: Self) -> Self::Output {
358        Self::from_arch_fn(self.1, |i| self.0[i].expected_shr_(rhs.0[i]))
359    }
360}
361
362impl<T, const N: usize, A> std::ops::Shr<T> for Emulated<T, N, A>
363where
364    T: ReferenceShifts,
365{
366    type Output = Self;
367    #[inline(always)]
368    fn shr(self, rhs: T) -> Self::Output {
369        Self::from_arch_fn(self.1, |i| self.0[i].expected_shr_(rhs))
370    }
371}
372
//////////////////
// Dot Products //
//////////////////

// i16 to i32
//
// `$N` is the number of 32-bit accumulator lanes; `$TwoN` (= 2 * $N) is the
// number of 16-bit input lanes — each i32 lane accumulates two i16 products.
macro_rules! impl_simd_dot_product_i16_to_i32 {
    ($N:literal, $TwoN:literal) => {
        /// Promote intermediate values to `i32` and then perform accumulation.
        impl<A> SIMDDotProduct<Emulated<i16, $TwoN, A>> for Emulated<i32, $N, A>
        where
            A: arch::Sealed,
        {
            /// For each output lane `i`, adds
            /// `left[2i]*right[2i] + left[2i+1]*right[2i+1]` to `self[i]`.
            fn dot_simd(
                self,
                left: Emulated<i16, $TwoN, A>,
                right: Emulated<i16, $TwoN, A>,
            ) -> Self {
                self + Self::from_arch_fn(self.1, |i| {
                    // Widen to i32 before multiplying so the i16 products cannot overflow.
                    let l0: i32 = left.0[2 * i].into();
                    let l1: i32 = left.0[2 * i + 1].into();

                    let r0: i32 = right.0[2 * i].into();
                    let r1: i32 = right.0[2 * i + 1].into();
                    l0.expected_fma_(r0, l1.expected_mul_(r1))
                })
            }
        }
    };
}
402
// i8/u8 to i32
//
// `$N` is the number of 32-bit accumulator lanes; `$FourN` (= 4 * $N) is the
// number of 8-bit input lanes — each 32-bit lane accumulates four byte products.
// (The metavariable was previously named `$TwoN`, which misstated the 4:1 ratio.)
macro_rules! impl_simd_dot_product_iu8_to_i32 {
    ($N:literal, $FourN:literal) => {
        /// Promote intermediate values to `i32` and then perform accumulation.
        impl<A> SIMDDotProduct<Emulated<u8, $FourN, A>, Emulated<i8, $FourN, A>>
            for Emulated<i32, $N, A>
        where
            A: arch::Sealed,
        {
            /// For each output lane `i`, adds the four widened byte products
            /// `Σ left[4i+k]*right[4i+k]` (k = 0..4) to `self[i]`.
            fn dot_simd(self, left: Emulated<u8, $FourN, A>, right: Emulated<i8, $FourN, A>) -> Self {
                self + Self::from_arch_fn(self.1, |i| {
                    // Widen to i32 before multiplying so the byte products cannot overflow.
                    let l0: i32 = left.0[4 * i].into();
                    let l1: i32 = left.0[4 * i + 1].into();
                    let l2: i32 = left.0[4 * i + 2].into();
                    let l3: i32 = left.0[4 * i + 3].into();

                    let r0: i32 = right.0[4 * i].into();
                    let r1: i32 = right.0[4 * i + 1].into();
                    let r2: i32 = right.0[4 * i + 2].into();
                    let r3: i32 = right.0[4 * i + 3].into();

                    let a = l0.expected_fma_(r0, l1.expected_mul_(r1));
                    let b = l2.expected_fma_(r2, l3.expected_mul_(r3));
                    a + b
                })
            }
        }

        /// Mixed-signedness dot product is symmetric: delegate to the
        /// `(u8, i8)` implementation with the operands swapped.
        impl<A> SIMDDotProduct<Emulated<i8, $FourN, A>, Emulated<u8, $FourN, A>>
            for Emulated<i32, $N, A>
        where
            A: arch::Sealed,
        {
            fn dot_simd(self, left: Emulated<i8, $FourN, A>, right: Emulated<u8, $FourN, A>) -> Self {
                self.dot_simd(right, left)
            }
        }

        /// Promote intermediate values to `u32` and then perform accumulation.
        impl<A> SIMDDotProduct<Emulated<u8, $FourN, A>, Emulated<u8, $FourN, A>>
            for Emulated<u32, $N, A>
        where
            A: arch::Sealed,
        {
            fn dot_simd(self, left: Emulated<u8, $FourN, A>, right: Emulated<u8, $FourN, A>) -> Self {
                self + Self::from_arch_fn(self.1, |i| {
                    // Widen to u32 before multiplying so the byte products cannot overflow.
                    let l0: u32 = left.0[4 * i].into();
                    let l1: u32 = left.0[4 * i + 1].into();
                    let l2: u32 = left.0[4 * i + 2].into();
                    let l3: u32 = left.0[4 * i + 3].into();

                    let r0: u32 = right.0[4 * i].into();
                    let r1: u32 = right.0[4 * i + 1].into();
                    let r2: u32 = right.0[4 * i + 2].into();
                    let r3: u32 = right.0[4 * i + 3].into();

                    let a = l0.expected_fma_(r0, l1.expected_mul_(r1));
                    let b = l2.expected_fma_(r2, l3.expected_mul_(r3));
                    a + b
                })
            }
        }

        /// Promote intermediate values to `i32` and then perform accumulation.
        impl<A> SIMDDotProduct<Emulated<i8, $FourN, A>, Emulated<i8, $FourN, A>>
            for Emulated<i32, $N, A>
        where
            A: arch::Sealed,
        {
            fn dot_simd(self, left: Emulated<i8, $FourN, A>, right: Emulated<i8, $FourN, A>) -> Self {
                self + Self::from_arch_fn(self.1, |i| {
                    // Widen to i32 before multiplying so the byte products cannot overflow.
                    let l0: i32 = left.0[4 * i].into();
                    let l1: i32 = left.0[4 * i + 1].into();
                    let l2: i32 = left.0[4 * i + 2].into();
                    let l3: i32 = left.0[4 * i + 3].into();

                    let r0: i32 = right.0[4 * i].into();
                    let r1: i32 = right.0[4 * i + 1].into();
                    let r2: i32 = right.0[4 * i + 2].into();
                    let r3: i32 = right.0[4 * i + 3].into();

                    let a = l0.expected_fma_(r0, l1.expected_mul_(r1));
                    let b = l2.expected_fma_(r2, l3.expected_mul_(r3));
                    a + b
                })
            }
        }
    };
}
490
// i16 dot products: N i32 accumulator lanes consume 2*N i16 input lanes.
impl_simd_dot_product_i16_to_i32!(4, 8);
impl_simd_dot_product_i16_to_i32!(8, 16);
impl_simd_dot_product_i16_to_i32!(16, 32);

// 8-bit dot products: N 32-bit accumulator lanes consume 4*N byte input lanes.
impl_simd_dot_product_iu8_to_i32!(4, 16);
impl_simd_dot_product_iu8_to_i32!(8, 32);
impl_simd_dot_product_iu8_to_i32!(16, 64);
498
499////////////
500// Select //
501////////////
502
503impl<T, const N: usize, A> SIMDSelect<Emulated<T, N, A>> for BitMask<N, A>
504where
505    T: Copy,
506    A: arch::Sealed,
507    Const<N>: SupportedLaneCount,
508    BitMask<N, A>: SIMDMask<Arch = A>,
509    Emulated<T, N, A>: SIMDVector<Mask = BitMask<N, A>>,
510{
511    #[inline(always)]
512    fn select(self, x: Emulated<T, N, A>, y: Emulated<T, N, A>) -> Emulated<T, N, A> {
513        Emulated::from_arch_fn(self.arch(), |i| {
514            if self.get_unchecked(i) {
515                x.0[i]
516            } else {
517                y.0[i]
518            }
519        })
520    }
521}
522
/////////////
// SumTree //
/////////////

// Horizontal (pairwise/tree-shaped) lane reduction for the listed
// scalar-type/lane-count pairs.
macro_rules! impl_sumtree {
    // Single instantiation: one `(type, lane count)` pair.
    ($T:ty, $N:literal) => {
        impl<A> SIMDSumTree for Emulated<$T, $N, A>
        where
            A: arch::Sealed,
        {
            /// Sum all lanes using pairwise (tree) reduction via `tree_reduce`.
            ///
            /// NOTE(review): for `f32` the reduction order affects rounding;
            /// the tree order here presumably matches the reference
            /// implementation's — confirm against `TreeReduce`.
            #[inline(always)]
            fn sum_tree(self) -> $T {
                self.0.tree_reduce(|x, y| x.expected_add_(y))
            }
        }
    };
    // Convenience arm: fan a list of lane counts out to the single-pair arm.
    // NOTE: arm order matters — the single-literal arm must come first, or a
    // one-element invocation would match this arm and recurse forever.
    ($T:ty, $($N:literal),* $(,)?) => {
        $(impl_sumtree!($T, $N);)*
    };
}

impl_sumtree!(f32, 1, 2, 4, 8, 16);
impl_sumtree!(i32, 4, 8, 16);
impl_sumtree!(u32, 4, 8, 16);
547
////////////////
// Conversion //
////////////////

// Lossless lane-wise widening conversions between emulated vectors with the
// same lane count.
macro_rules! impl_from {
    // Special case for `f16 => f32`: lanes are converted with `reference_cast`
    // rather than `Into`, keeping the emulated path on the crate's reference
    // cast semantics.
    (f16 => f32, $N:literal) => {
        impl<A> From<Emulated<f16, $N, A>> for Emulated<f32, $N, A> {
            #[inline(always)]
            fn from(value: Emulated<f16, $N, A>) -> Self {
                Emulated(value.0.map(|v| v.reference_cast()), value.1)
            }
        }
    };
    // General case: lane types with a lossless `Into` conversion.
    ($from:ty => $to:ty, $N:literal) => {
        impl<A> From<Emulated<$from, $N, A>> for Emulated<$to, $N, A> {
            #[inline(always)]
            fn from(value: Emulated<$from, $N, A>) -> Self {
                Emulated(value.0.map(|v| v.into()), value.1)
            }
        }
    };
}

impl_from!(f16 => f32, 1);
impl_from!(f16 => f32, 2);
impl_from!(f16 => f32, 4);
impl_from!(f16 => f32, 8);
impl_from!(f16 => f32, 16);

impl_from!(u8 => i16, 16);
impl_from!(u8 => i16, 32);

impl_from!(i8 => i16, 16);
impl_from!(i8 => i16, 32);

impl_from!(i8 => i32, 1);
impl_from!(i8 => i32, 4);

impl_from!(u8 => i32, 1);
impl_from!(u8 => i32, 4);
588
/////////////////
// Reinterpret //
/////////////////

// Bit-level reinterpretation between emulated vectors of equal total byte
// width. Gated with `cfg(target_endian = "little")` — presumably because the
// lane-to-byte layout this models differs on big-endian targets.
macro_rules! impl_little_endian_transmute_cast {
    (<$from:ty, $Nfrom:literal> => <$to:ty, $Nto:literal>) => {
        #[cfg(target_endian = "little")]
        impl<A> SIMDReinterpret<Emulated<$to, $Nto, A>> for Emulated<$from, $Nfrom, A>
        where
            A: arch::Sealed,
        {
            /// Reinterpret the raw bytes of `self` as `$Nto` lanes of `$to`,
            /// keeping the same arch token.
            fn reinterpret_simd(self) -> Emulated<$to, $Nto, A> {
                let array = self.0;
                // SAFETY: This is only ever instantiated with arrays of primitive
                // types that hold no resources, no padding, and are valid for all
                // possible bit-patterns.
                let casted = unsafe { std::mem::transmute::<[$from; $Nfrom], [$to; $Nto]>(array) };
                Emulated(casted, self.1)
            }
        }
    };
}

impl_little_endian_transmute_cast!(<u32, 8> => <i16, 16>);

impl_little_endian_transmute_cast!(<i16, 8> => <u8, 16>);
impl_little_endian_transmute_cast!(<u8, 16> => <i16, 8>);

impl_little_endian_transmute_cast!(<u32, 16> => <u8, 64>);
impl_little_endian_transmute_cast!(<u32, 16> => <i8, 64>);

impl_little_endian_transmute_cast!(<u8, 64> => <u32, 16>);
impl_little_endian_transmute_cast!(<i8, 64> => <u32, 16>);
622
623/////////////
624// Casting //
625/////////////
626
627macro_rules! impl_cast {
628    ($from:ty => $to:ty, $N:literal) => {
629        impl<A> SIMDCast<$to> for Emulated<$from, $N, A>
630        where
631            A: arch::Sealed,
632        {
633            type Cast = Emulated<$to, $N, A>;
634            #[inline(always)]
635            fn simd_cast(self) -> Self::Cast {
636                Emulated::from_arch_fn(self.arch(), |i| self.0[i].reference_cast())
637            }
638        }
639    };
640}
641
642impl_cast!(f16 => f32, 8);
643impl_cast!(f16 => f32, 16);
644
645impl_cast!(f32 => f16, 8);
646impl_cast!(f32 => f16, 16);
647
648impl_cast!(i32 => f32, 8);
649
///////////////
// SplitJoin //
///////////////

// Splits a vector into low/high halves (and rejoins them), delegating to the
// underlying array `split`/`join` and carrying the arch token along.
macro_rules! impl_splitjoin {
    ($type:ty, $N:literal => $N2:literal) => {
        impl<A> SplitJoin for Emulated<$type, $N, A>
        where
            A: Copy,
        {
            type Halved = Emulated<$type, $N2, A>;

            /// Split into low and high halves; both halves carry the same arch token.
            #[inline(always)]
            fn split(self) -> $crate::LoHi<Self::Halved> {
                let $crate::LoHi { lo, hi } = self.0.split();
                let arch = self.1;
                $crate::LoHi::new(Emulated(lo, arch), Emulated(hi, arch))
            }

            /// Concatenate the two halves back into a full-width vector.
            ///
            /// The arch token is taken from the low half.
            #[inline(always)]
            fn join(lohi: $crate::LoHi<Self::Halved>) -> Self {
                Self($crate::LoHi::new(lohi.lo.0, lohi.hi.0).join(), lohi.lo.1)
            }
        }
    };
}

impl_splitjoin!(i8, 32 => 16);
impl_splitjoin!(i8, 64 => 32);

impl_splitjoin!(i16, 16 => 8);
impl_splitjoin!(i16, 32 => 16);

impl_splitjoin!(i32, 8 => 4);
impl_splitjoin!(i32, 16 => 8);

impl_splitjoin!(u8, 32 => 16);
impl_splitjoin!(u8, 64 => 32);

impl_splitjoin!(u32, 8 => 4);
impl_splitjoin!(u32, 16 => 8);
impl_splitjoin!(u64, 4 => 2);

impl_splitjoin!(f32, 16 => 8);
impl_splitjoin!(f32, 8 => 4);

impl_splitjoin!(f16, 16 => 8);
697
698//////////////
699// ZipUnzip //
700//////////////
701
702macro_rules! array_zipunzip {
703    ($N:literal) => {
704        impl<T: Copy> crate::traits::ZipUnzip for [T; $N] {
705            #[inline(always)]
706            fn zip(halves: $crate::LoHi<Self::Halved>) -> Self {
707                core::array::from_fn(|i| {
708                    if i % 2 == 0 {
709                        halves.lo[i / 2]
710                    } else {
711                        halves.hi[i / 2]
712                    }
713                })
714            }
715
716            #[inline(always)]
717            fn unzip(self) -> $crate::LoHi<Self::Halved> {
718                $crate::LoHi {
719                    lo: core::array::from_fn(|i| self[2 * i]),
720                    hi: core::array::from_fn(|i| self[2 * i + 1]),
721                }
722            }
723        }
724    };
725}
726
727array_zipunzip!(2);
728array_zipunzip!(4);
729array_zipunzip!(8);
730array_zipunzip!(16);
731array_zipunzip!(32);
732array_zipunzip!(64);
733
// Lane interleave/deinterleave for emulated vectors, delegating to the array
// `zip`/`unzip` implementations above.
macro_rules! impl_zipunzip {
    ($type:ty, $N:literal => $N2:literal) => {
        impl<A> ZipUnzip for Emulated<$type, $N, A>
        where
            A: Copy,
        {
            /// Interleave the two half-width vectors (`lo` supplies even lanes,
            /// `hi` supplies odd lanes).
            ///
            /// The arch token is taken from the low half.
            #[inline(always)]
            fn zip(halves: $crate::LoHi<Self::Halved>) -> Self {
                Self(
                    $crate::LoHi::new(halves.lo.0, halves.hi.0).zip(),
                    halves.lo.1,
                )
            }

            /// Deinterleave into even-lane (`lo`) and odd-lane (`hi`) vectors,
            /// both carrying the same arch token.
            #[inline(always)]
            fn unzip(self) -> $crate::LoHi<Self::Halved> {
                let $crate::LoHi { lo, hi } = self.0.unzip();
                let arch = self.1;
                $crate::LoHi::new(Emulated(lo, arch), Emulated(hi, arch))
            }
        }
    };
}

impl_zipunzip!(i8, 32 => 16);
impl_zipunzip!(i16, 16 => 8);
impl_zipunzip!(i32, 8 => 4);
impl_zipunzip!(u8, 32 => 16);
impl_zipunzip!(u32, 8 => 4);
impl_zipunzip!(f16, 16 => 8);
764
765///////////
766// Tests //
767///////////
768
769#[cfg(test)]
770mod test_emulated {
771    use half::f16;
772
773    use super::*;
774    use crate::{reference::ReferenceScalarOps, test_utils};
775
776    // Test loading logic - ensure that no out of bounds accesses are made.
777    // In particular, this is meant to be run under `Miri` to ensure that our guarantees
778    // regarding out-of-bounds accesses are honored.
779    #[test]
780    fn test_load() {
781        // Floating Point
782        #[cfg(not(miri))] // Miri does not have ph-to-ps conversion.
783        test_utils::test_load_simd::<f16, 8, Emulated<f16, 8>>(Scalar);
784        test_utils::test_load_simd::<f32, 4, Emulated<f32, 4>>(Scalar);
785        test_utils::test_load_simd::<f32, 8, Emulated<f32, 8>>(Scalar);
786
787        // Unsigned Integers
788        test_utils::test_load_simd::<u8, 8, Emulated<u8, 8>>(Scalar);
789        test_utils::test_load_simd::<u8, 16, Emulated<u8, 16>>(Scalar);
790
791        test_utils::test_load_simd::<u16, 4, Emulated<u16, 4>>(Scalar);
792        test_utils::test_load_simd::<u16, 8, Emulated<u16, 8>>(Scalar);
793        test_utils::test_load_simd::<u16, 16, Emulated<u16, 16>>(Scalar);
794
795        test_utils::test_load_simd::<u32, 2, Emulated<u32, 2>>(Scalar);
796        test_utils::test_load_simd::<u32, 4, Emulated<u32, 4>>(Scalar);
797        test_utils::test_load_simd::<u32, 8, Emulated<u32, 8>>(Scalar);
798
799        // Signed Integers
800        test_utils::test_load_simd::<i8, 8, Emulated<i8, 8>>(Scalar);
801        test_utils::test_load_simd::<i8, 16, Emulated<i8, 16>>(Scalar);
802
803        test_utils::test_load_simd::<i16, 4, Emulated<i16, 4>>(Scalar);
804        test_utils::test_load_simd::<i16, 8, Emulated<i16, 8>>(Scalar);
805        test_utils::test_load_simd::<i16, 16, Emulated<i16, 16>>(Scalar);
806
807        test_utils::test_load_simd::<i32, 2, Emulated<i32, 2>>(Scalar);
808        test_utils::test_load_simd::<i32, 4, Emulated<i32, 4>>(Scalar);
809        test_utils::test_load_simd::<i32, 8, Emulated<i32, 8>>(Scalar);
810    }
811
812    #[test]
813    fn test_store() {
814        // Floating Point
815        #[cfg(not(miri))] // Miri does not have ph-to-ps conversion.
816        test_utils::test_store_simd::<f16, 8, Emulated<f16, 8>>(Scalar);
817        test_utils::test_store_simd::<f32, 4, Emulated<f32, 4>>(Scalar);
818        test_utils::test_store_simd::<f32, 8, Emulated<f32, 8>>(Scalar);
819
820        // Unsigned Integers
821        test_utils::test_store_simd::<u8, 8, Emulated<u8, 8>>(Scalar);
822        test_utils::test_store_simd::<u8, 16, Emulated<u8, 16>>(Scalar);
823
824        test_utils::test_store_simd::<u16, 4, Emulated<u16, 4>>(Scalar);
825        test_utils::test_store_simd::<u16, 8, Emulated<u16, 8>>(Scalar);
826        test_utils::test_store_simd::<u16, 16, Emulated<u16, 16>>(Scalar);
827
828        test_utils::test_store_simd::<u32, 2, Emulated<u32, 2>>(Scalar);
829        test_utils::test_store_simd::<u32, 4, Emulated<u32, 4>>(Scalar);
830        test_utils::test_store_simd::<u32, 8, Emulated<u32, 8>>(Scalar);
831
832        // Signed Integers
833        test_utils::test_store_simd::<i8, 8, Emulated<i8, 8>>(Scalar);
834        test_utils::test_store_simd::<i8, 16, Emulated<i8, 16>>(Scalar);
835
836        test_utils::test_store_simd::<i16, 4, Emulated<i16, 4>>(Scalar);
837        test_utils::test_store_simd::<i16, 8, Emulated<i16, 8>>(Scalar);
838        test_utils::test_store_simd::<i16, 16, Emulated<i16, 16>>(Scalar);
839
840        test_utils::test_store_simd::<i32, 2, Emulated<i32, 2>>(Scalar);
841        test_utils::test_store_simd::<i32, 4, Emulated<i32, 4>>(Scalar);
842        test_utils::test_store_simd::<i32, 8, Emulated<i32, 8>>(Scalar);
843    }
844
845    // Only test a subset of constructors as all `Emulated` have the same implementation.
846    #[test]
847    fn test_constructors() {
848        test_utils::ops::test_splat::<u8, 64, Emulated<u8, 64>>(Scalar);
849        let x = Emulated::<u32, 8>::default(Scalar);
850        assert_eq!(x.to_underlying(), [0; 8]);
851
852        let x = Emulated::<u32, 8>::from_underlying(Scalar, [1; 8]);
853        assert_eq!(x.to_underlying(), [1; 8]);
854    }
855
856    // Wrap inside `Some` for compatibility with optional tests.
857    const SC: Option<Scalar> = Some(Scalar);
858
859    macro_rules! test_emulated {
860        ($type:ty, $N:literal) => {
861            test_utils::ops::test_add!(Emulated<$type, $N>, 0xba37c3f2cf666f87, SC);
862            test_utils::ops::test_sub!(Emulated<$type, $N>, 0xeb755abd230e5d80, SC);
863            test_utils::ops::test_mul!(Emulated<$type, $N>, 0x0a24ed76a54c3561, SC);
864            test_utils::ops::test_fma!(Emulated<$type, $N>, 0xa906c44505abe9ca, SC);
865            test_utils::ops::test_minmax!(Emulated<$type, $N>, 0x959522be5234d492, SC);
866
867            test_utils::ops::test_cmp!(Emulated<$type, $N>, 0x9b58e6cbd8330c2d, SC);
868            test_utils::ops::test_select!(Emulated<$type, $N>, 0x610aca3aa4d77c0a, SC);
869        };
870        (unsigned, $type:ty, $N:literal) => {
871            test_emulated!($type, $N);
872
873            test_utils::ops::test_bitops!(Emulated<$type, $N>, 0x14fc7841e66bd162, SC);
874        };
875        (signed, $type:ty, $N:literal) => {
876            test_emulated!($type, $N);
877
878            test_utils::ops::test_bitops!(Emulated<$type, $N>, 0x850435f89f86f3b0, SC);
879            test_utils::ops::test_abs!(Emulated<$type, $N>, 0x1842a2b86dfd9ecb, SC);
880        };
881    }
882
883    // Emulated arithmetic.
884    test_emulated!(f32, 1);
885    test_emulated!(f32, 4);
886    test_emulated!(f32, 8);
887    test_emulated!(f32, 16);
888    // test_emulated!(f64, 8);
889
890    // unsigned integer
891    test_emulated!(unsigned, u8, 16);
892
893    test_emulated!(unsigned, u16, 16);
894    test_emulated!(unsigned, u16, 32);
895
896    test_emulated!(unsigned, u32, 1);
897    test_emulated!(unsigned, u32, 4);
898    test_emulated!(unsigned, u32, 8);
899    test_emulated!(unsigned, u32, 16);
900
901    test_emulated!(unsigned, u64, 2);
902    test_emulated!(unsigned, u64, 4);
903    test_emulated!(unsigned, u64, 8);
904    test_emulated!(unsigned, u64, 16);
905
906    // signed integer
907    test_emulated!(signed, i8, 8);
908    test_emulated!(signed, i8, 16);
909
910    test_emulated!(signed, i16, 8);
911    test_emulated!(signed, i16, 16);
912
913    test_emulated!(signed, i32, 1);
914    test_emulated!(signed, i32, 4);
915    test_emulated!(signed, i32, 8);
916    test_emulated!(signed, i32, 16);
917
918    test_emulated!(signed, i64, 2);
919    test_emulated!(signed, i64, 4);
920    test_emulated!(signed, i64, 8);
921    test_emulated!(signed, i64, 16);
922
    // Dot Products
    //
    // Each instantiation pairs two input vector types with a wider accumulator
    // type: i16 lanes fold 2:1 into i32 lanes (16 -> 8, 32 -> 16), and 8-bit
    // lanes fold 4:1 into 32-bit lanes (32 -> 8, 64 -> 16). The hex literal is
    // presumably an RNG seed for the generated inputs and `SC` the scalar-arch
    // token -- TODO confirm against `test_dot_product!`'s definition.
    // NOTE(review): the same seed 0x3001f05604e96289 recurs across several
    // instantiations below; looks like copy-paste but is harmless for a seed.
    test_utils::dot_product::test_dot_product!(
        (Emulated<i16, 16>, Emulated<i16, 16>) => Emulated<i32, 8>, 0x3001f05604e96289, SC
    );
    test_utils::dot_product::test_dot_product!(
        (Emulated<i16, 32>, Emulated<i16, 32>) => Emulated<i32, 16>, 0x137ce7a540d9b1a2, SC
    );

    // 8-bit operands, 32 lanes: every signedness combination that yields a
    // signed i32 accumulator (u8*i8, i8*u8, i8*i8).
    test_utils::dot_product::test_dot_product!(
        (Emulated<u8, 32>, Emulated<i8, 32>) => Emulated<i32, 8>, 0x3001f05604e96289, SC
    );
    test_utils::dot_product::test_dot_product!(
        (Emulated<i8, 32>, Emulated<u8, 32>) => Emulated<i32, 8>, 0x3001f05604e96289, SC
    );
    test_utils::dot_product::test_dot_product!(
        (Emulated<i8, 32>, Emulated<i8, 32>) => Emulated<i32, 8>, 0x3001f05604e96289, SC
    );

    // Same signedness combinations at 64 lanes.
    test_utils::dot_product::test_dot_product!(
        (Emulated<u8, 64>, Emulated<i8, 64>) => Emulated<i32, 16>, 0x3001f05604e96289, SC
    );
    test_utils::dot_product::test_dot_product!(
        (Emulated<i8, 64>, Emulated<u8, 64>) => Emulated<i32, 16>, 0x3001f05604e96289, SC
    );
    test_utils::dot_product::test_dot_product!(
        (Emulated<i8, 64>, Emulated<i8, 64>) => Emulated<i32, 16>, 0x3001f05604e96289, SC
    );

    // Fully unsigned operands accumulate into an unsigned u32 vector.
    test_utils::dot_product::test_dot_product!(
        (Emulated<u8, 32>, Emulated<u8, 32>) => Emulated<u32, 8>, 0x3001f05604e96289, SC
    );
    test_utils::dot_product::test_dot_product!(
        (Emulated<u8, 64>, Emulated<u8, 64>) => Emulated<u32, 16>, 0x3001f05604e96289, SC
    );
957
    // reductions
    //
    // Horizontal tree-sum reductions (`SIMDSumTree`) over f32/i32/u32 vectors
    // at the supported lane counts. Each hex literal is presumably a distinct
    // RNG seed for that instantiation's inputs -- TODO confirm against
    // `test_sumtree!`'s definition.
    test_utils::ops::test_sumtree!(Emulated<f32, 1>, 0x410bad8207a8ccfc, SC);
    test_utils::ops::test_sumtree!(Emulated<f32, 2>, 0xf2fc4e4bbd193493, SC);
    test_utils::ops::test_sumtree!(Emulated<f32, 4>, 0x8034d5a0cd2be14d, SC);
    test_utils::ops::test_sumtree!(Emulated<f32, 8>, 0x0f075940b7e3732c, SC);
    test_utils::ops::test_sumtree!(Emulated<f32, 16>, 0x5b3cb860e3f02d3c, SC);

    test_utils::ops::test_sumtree!(Emulated<i32, 4>, 0xf8c38f70a807e9d2, SC);
    test_utils::ops::test_sumtree!(Emulated<i32, 8>, 0xf8aa4a7e7a273e80, SC);
    test_utils::ops::test_sumtree!(Emulated<i32, 16>, 0x8d1a467fe835a9c5, SC);

    test_utils::ops::test_sumtree!(Emulated<u32, 4>, 0x5e4cffc86a21e90d, SC);
    test_utils::ops::test_sumtree!(Emulated<u32, 8>, 0xf43f19adb43bc611, SC);
    test_utils::ops::test_sumtree!(Emulated<u32, 16>, 0xa43dfe10aa9de860, SC);
972
    /////////////////
    // conversions //
    /////////////////

    // Lossless widening conversions: i8/u8 sources widened to i16 or i32
    // destinations at matching lane counts. Every source value fits the
    // destination type exactly, hence "lossless".

    // i8 -> i16
    test_utils::ops::test_lossless_convert!(
        Emulated<i8, 16> => Emulated<i16, 16>, 0x1b4f08a8b741d565, SC
    );
    test_utils::ops::test_lossless_convert!(
        Emulated<i8, 32> => Emulated<i16, 32>, 0xdf6f41eb836d4f46, SC
    );

    // i8 -> i32
    test_utils::ops::test_lossless_convert!(
        Emulated<i8, 1> => Emulated<i32, 1>, 0x318ceec0e9798353, SC
    );
    test_utils::ops::test_lossless_convert!(
        Emulated<i8, 4> => Emulated<i32, 4>, 0x9f5e1a437f7e7f3f, SC
    );

    // u8 -> i16
    test_utils::ops::test_lossless_convert!(
        Emulated<u8, 16> => Emulated<i16, 16>, 0x96611521fed02f98, SC
    );
    test_utils::ops::test_lossless_convert!(
        Emulated<u8, 32> => Emulated<i16, 32>, 0x6749d3aa94effa04, SC
    );

    // u8 -> i32
    test_utils::ops::test_lossless_convert!(
        Emulated<u8, 1> => Emulated<i32, 1>, 0x669cbd5c7bf6184e, SC
    );
    test_utils::ops::test_lossless_convert!(
        Emulated<u8, 4> => Emulated<i32, 4>, 0x75929494c5d333d0, SC
    );
1004
    ///////////
    // Casts //
    ///////////

    // `SIMDCast` coverage: f16 <-> f32 in both directions plus i32 -> f32.
    // NOTE(review): unlike `test_lossless_convert!` above, f32 -> f16 and
    // i32 -> f32 can lose precision -- presumably `test_cast!` checks against
    // the reference cast (`ReferenceCast`) rather than exact round-trips;
    // confirm against the macro definition.
    test_utils::ops::test_cast!(Emulated<f16, 8> => Emulated<f32, 8>, 0x1e9e37b58fb3f1a8, SC);
    test_utils::ops::test_cast!(Emulated<f16, 16> => Emulated<f32, 16>, 0xd2b068a9bf3f9d24, SC);

    test_utils::ops::test_cast!(Emulated<f32, 8> => Emulated<f16, 8>, 0xe9d2dd426d89699d, SC);
    test_utils::ops::test_cast!(Emulated<f32, 16> => Emulated<f16, 16>, 0x2b637e21afd9ef6c, SC);

    test_utils::ops::test_cast!(Emulated<i32, 8> => Emulated<f32, 8>, 0x2b08e8ec7e49323b, SC);
1016
    //////////////
    // ZipUnzip //
    //////////////

    // `ZipUnzip` coverage: each N-lane vector type maps to its half-lane
    // counterpart (32 -> 16, 16 -> 8, 8 -> 4) across signed, unsigned, and
    // f16 element types.
    test_utils::ops::test_zipunzip!(Emulated<i8, 32> => Emulated<i8, 16>, 0xa7c3e1f920b45d68, SC);
    test_utils::ops::test_zipunzip!(Emulated<i16, 16> => Emulated<i16, 8>, 0x6b8d2f0e41c7a593, SC);
    test_utils::ops::test_zipunzip!(Emulated<i32, 8> => Emulated<i32, 4>, 0x5f1a8c63d702be94, SC);
    test_utils::ops::test_zipunzip!(Emulated<u8, 32> => Emulated<u8, 16>, 0x92d5f4a83e1b07c6, SC);
    test_utils::ops::test_zipunzip!(Emulated<u32, 8> => Emulated<u32, 4>, 0xb6f30d8a52e4c197, SC);
    test_utils::ops::test_zipunzip!(Emulated<f16, 16> => Emulated<f16, 8>, 0x8b4e6d1fa07c9253, SC);
1027}