simd_trick/
lib.rs

1#![no_std]
2#![forbid(unsafe_code)]
3
4//! A crate to trick the optimizer into generating SIMD instructions.
5
6use core::{fmt, mem, ops};
7
8#[repr(C)]
9#[derive(Clone, Copy)]
10pub struct Simd<TArray>(TArray, <Self as Sealed>::Align) where Self: Vector;
11
12pub trait Vector: Copy + Sealed {
13    type Element: Copy;
14    type MaskVector: Vector;
15}
16
17fn simd<TArray>(array: TArray) -> Simd<TArray>
18    where Simd<TArray>: Vector
19{
20    Simd(array, Default::default())
21}
22
23impl<TArray> ops::Deref for Simd<TArray>
24    where Self: Vector
25{
26    type Target = TArray;
27    #[inline]
28    fn deref(&self) -> &TArray {
29        &self.0
30    }
31}
32impl<TArray> ops::DerefMut for Simd<TArray>
33    where Self: Vector
34{
35    #[inline]
36    fn deref_mut(&mut self) -> &mut TArray {
37        &mut self.0
38    }
39}
40
// Generates every supported vector shape.
//
// Input is grouped by alignment marker, then by mask type: within an
// alignment group, `@mask uint` starts a run of `alias elem lanes` triples
// whose comparisons produce `mask`. For each triple the macro emits the public
// alias `alias = Simd<[elem; lanes]>`, its `Sealed`/`Vector`/`SimdImpl` impls
// and a conversion back into the raw array. For each `@mask uint` header it
// emits `From<mask> for uint`.
macro_rules! define_vector_type {
    // Internal: everything generated for one `alias elem lanes` triple.
    (@vector $align:ident $mask:ident $alias:ident $elem:ident $lanes:literal) => {
        #[allow(non_camel_case_types)]
        pub type $alias = Simd<[$elem; $lanes]>;

        impl Sealed for Simd<[$elem; $lanes]> {
            type Align = $align;
        }

        impl Vector for Simd<[$elem; $lanes]> {
            type Element = $elem;
            type MaskVector = $mask;
        }

        impl SimdImpl for Simd<[$elem; $lanes]> {
            type Array = [Self::Element; $lanes];
            type Mask = <Self::MaskVector as Vector>::Element;

            fn as_slice(&self) -> &[Self::Element] {
                &self.0
            }

            #[inline]
            fn repeat(value: Self::Element) -> Self {
                simd([value; $lanes])
            }

            #[inline]
            fn map(self, f: impl Fn($elem) -> $elem) -> Self {
                simd(array_utils::map(self.0, f))
            }

            #[inline]
            fn zip(self, other: Self, f: impl Fn($elem, $elem) -> $elem) -> Self {
                simd(array_utils::zip(self.0, other.0, f))
            }

            #[inline]
            fn zip_mask(
                self,
                other: Self,
                f: impl Fn($elem, $elem) -> Self::Mask,
            ) -> Self::MaskVector {
                simd(array_utils::zip(self.0, other.0, f))
            }
        }

        impl From<Simd<[$elem; $lanes]>> for [$elem; $lanes] {
            #[inline]
            fn from(vector: Simd<[$elem; $lanes]>) -> Self {
                vector.0
            }
        }
    };

    // Internal: mask vector -> unsigned vector of the same shape, converting
    // each lane with the mask lane's `Into` impl.
    (@mask_to_uint $mask:ident $uint:ident) => {
        impl From<$mask> for $uint {
            #[inline]
            fn from(mask: $mask) -> $uint {
                simd(array_utils::map(mask.0, Into::into))
            }
        }
    };

    // Public entry point.
    ($($align:ident ($(@$mask:ident $uint:ident $($alias:ident $elem:ident $lanes:literal)+)+))+) => {
        $($(
            $(
                define_vector_type!(@vector $align $mask $alias $elem $lanes);
            )+
            define_vector_type!(@mask_to_uint $mask $uint);
        )+)+
    };
}
98define_vector_type!(
99    Align8 (
100        @m8x8 u8x8
101        i8x8 i8 8
102        u8x8 u8 8
103        m8x8 m8 8
104
105        @m16x4 u16x4
106        i16x4 i16 4
107        u16x4 u16 4
108        m16x4 m16 4
109
110        @m32x2 u32x2
111        i32x2 i32 2
112        u32x2 u32 2
113        m32x2 m32 2
114        f32x2 f32 2
115
116        @m64x1 u64x1
117        i64x1 i64 1
118        u64x1 u64 1
119        m64x1 m64 1
120        f64x1 f64 1
121    )
122    Align16 (
123        @m8x16 u8x16
124        i8x16 i8 16
125        u8x16 u8 16
126        m8x16 m8 16
127
128        @m16x8 u16x8
129        i16x8 i16 8
130        u16x8 u16 8
131        m16x8 m16 8
132
133        @m32x4 u32x4
134        i32x4 i32 4
135        u32x4 u32 4
136        m32x4 m32 4
137        f32x4 f32 4
138
139        @m64x2 u64x2
140        i64x2 i64 2
141        u64x2 u64 2
142        m64x2 m64 2
143        f64x2 f64 2
144    )
145    Align32 (
146        @m8x32 u8x32
147        i8x32 i8 32
148        u8x32 u8 32
149        m8x32 m8 32
150
151        @m16x16 u16x16
152        i16x16 i16 16
153        u16x16 u16 16
154        m16x16 m16 16
155
156        @m32x8 u32x8
157        i32x8 i32 8
158        u32x8 u32 8
159        m32x8 m32 8
160        f32x8 f32 8
161
162        @m64x4 u64x4
163        i64x4 i64 4
164        u64x4 u64 4
165        m64x4 m64 4
166        f64x4 f64 4
167    )
168    Align64 (
169        @m8x64 u8x64
170        i8x64 i8 64
171        u8x64 u8 64
172        m8x64 m8 64
173
174        @m16x32 u16x32
175        i16x32 i16 32
176        u16x32 u16 32
177        m16x32 m16 32
178
179        @m32x16 u32x16
180        i32x16 i32 16
181        u32x16 u32 16
182        m32x16 m32 16
183        f32x16 f32 16
184
185        @m64x8 u64x8
186        i64x8 i64 8
187        u64x8 u64 8
188        m64x8 m64 8
189        f64x8 f64 8
190    )
191);
192
193impl<TArray> From<TArray> for Simd<TArray>
194    where Self: Vector
195{
196    #[inline]
197    fn from(array: TArray) -> Self {
198        simd(array)
199    }
200}
201
202impl<TArray> Simd<TArray>
203    where Self: SimdImpl
204{
205    #[inline]
206    pub fn splat(value: <Self as Vector>::Element) -> Self {
207        Self::repeat(value)
208    }
209}
210
211impl<TArray> Default for Simd<TArray>
212where
213    Self: SimdImpl,
214    <Self as Vector>::Element: Default
215{
216    #[inline]
217    fn default() -> Self {
218        Self::splat(Default::default())
219    }
220}
221
222impl<TArray> fmt::Debug for Simd<TArray>
223where
224    Self: SimdImpl,
225    <Self as Vector>::Element: fmt::Debug
226{
227    #[inline]
228    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
229        fmt::Debug::fmt(self.as_slice(), f)
230    }
231}
232
233impl<TArray> Simd<TArray>
234where
235    Self: SimdImpl,
236    <Self as Vector>::Element: PartialOrd
237{
238    #[inline]
239    pub fn eq(self, other: Self) -> <Self as Vector>::MaskVector {
240        self.zip_mask(other, |a, b| (a == b).into())
241    }
242
243    #[inline]
244    pub fn ne(self, other: Self) -> <Self as Vector>::MaskVector {
245        self.zip_mask(other, |a, b| (a != b).into())
246    }
247
248    #[inline]
249    pub fn lt(self, other: Self) -> <Self as Vector>::MaskVector {
250        self.zip_mask(other, |a, b| (a < b).into())
251    }
252
253    #[inline]
254    pub fn gt(self, other: Self) -> <Self as Vector>::MaskVector {
255        self.zip_mask(other, |a, b| (a > b).into())
256    }
257
258    #[inline]
259    pub fn le(self, other: Self) -> <Self as Vector>::MaskVector {
260        self.zip_mask(other, |a, b| (a <= b).into())
261    }
262
263    #[inline]
264    pub fn ge(self, other: Self) -> <Self as Vector>::MaskVector {
265        self.zip_mask(other, |a, b| (a >= b).into())
266    }
267}
268
269impl<TArray> Simd<TArray>
270where
271    Self: SimdImpl,
272    <Self as Vector>::Element: Ord
273{
274    #[inline]
275    pub fn min(self, other: Self) -> Self {
276        self.zip(other, Ord::min)
277    }
278
279    #[inline]
280    pub fn max(self, other: Self) -> Self {
281        self.zip(other, Ord::max)
282    }
283}
284
285impl<TArray> Simd<TArray>
286where
287    Self: SimdImpl,
288    <Self as Vector>::Element: Integer
289{
290    #[inline]
291    pub fn wrapping_add(self, other: Self) -> Self {
292        self.zip(other, Integer::wrapping_add)
293    }
294
295    #[inline]
296    pub fn wrapping_sub(self, other: Self) -> Self {
297        self.zip(other, Integer::wrapping_sub)
298    }
299
300    #[inline]
301    pub fn wrapping_mul(self, other: Self) -> Self {
302        self.zip(other, Integer::wrapping_mul)
303    }
304
305    #[inline]
306    pub fn high_mul(self, other: Self) -> Self {
307        self.zip(other, Integer::high_mul)
308    }
309
310    #[inline]
311    pub fn saturating_add(self, other: Self) -> Self {
312        self.zip(other, Integer::saturating_add)
313    }
314
315    #[inline]
316    pub fn saturating_sub(self, other: Self) -> Self {
317        self.zip(other, Integer::saturating_sub)
318    }
319
320    #[inline]
321    pub fn count_ones(self) -> Self {
322        self.map(Integer::count_ones)
323    }
324
325    #[inline]
326    pub fn count_zeros(self) -> Self {
327        self.map(Integer::count_zeros)
328    }
329}
330
331impl<TArray> Simd<TArray>
332where
333    Self: SimdImpl,
334    <Self as Vector>::Element: SignedInteger
335{
336    #[inline]
337    pub fn wrapping_abs(self) -> Self {
338        self.map(SignedInteger::wrapping_abs)
339    }
340}
341
342impl<TArray> Simd<TArray>
343where
344    Self: SimdImpl,
345    <Self as Vector>::Element: FloatingPoint
346{
347    #[inline]
348    pub fn recip(self) -> Self {
349        self.map(FloatingPoint::recip)
350    }
351
352    #[inline]
353    pub fn to_degrees(self) -> Self {
354        self.map(FloatingPoint::to_degrees)
355    }
356
357    #[inline]
358    pub fn to_radians(self) -> Self {
359        self.map(FloatingPoint::to_radians)
360    }
361
362    #[inline]
363    pub fn min_naive(self, other: Self) -> Self {
364        self.zip(other, FloatingPoint::min_naive)
365    }
366
367    #[inline]
368    pub fn max_naive(self, other: Self) -> Self {
369        self.zip(other, FloatingPoint::max_naive)
370    }
371}
372
// Implements the binary operator `ops::$op_trait` for `Simd` by applying the
// lane type's own operator to each pair of corresponding lanes. The optional
// `where Bound` clause restricts the impl to lane types implementing `Bound`.
macro_rules! forward_ops_as_zip {
    ($($op_trait:ident $op_fn:ident $(where $extra:ident)? ,)+) => {
        $(
            impl<TArray> ops::$op_trait for Simd<TArray>
            where
                Self: SimdImpl,
                <Self as Vector>::Element: ops::$op_trait<Output = <Self as Vector>::Element>,
                $( <Self as Vector>::Element: $extra, )?
            {
                type Output = Self;

                #[inline]
                fn $op_fn(self, rhs: Self) -> Self {
                    SimdImpl::zip(self, rhs, ops::$op_trait::$op_fn)
                }
            }
        )+
    };
}
// Implements the unary operator `ops::$op_trait` for `Simd` by applying the
// lane type's own operator to every lane. The optional `where Bound` clause
// restricts the impl to lane types implementing `Bound`.
macro_rules! forward_ops_as_map {
    ($($op_trait:ident $op_fn:ident $(where $extra:ident)? ,)+) => {
        $(
            impl<TArray> ops::$op_trait for Simd<TArray>
            where
                Self: SimdImpl,
                <Self as Vector>::Element: ops::$op_trait<Output = <Self as Vector>::Element>,
                $( <Self as Vector>::Element: $extra, )?
            {
                type Output = Self;

                #[inline]
                fn $op_fn(self) -> Self {
                    SimdImpl::map(self, ops::$op_trait::$op_fn)
                }
            }
        )+
    };
}
405forward_ops_as_zip!(
406    BitAnd bitand,
407    BitOr bitor,
408    BitXor bitxor,
409
410    Add add where FloatingPoint,
411    Sub sub where FloatingPoint,
412    Mul mul where FloatingPoint,
413    Div div where FloatingPoint,
414    Rem rem where FloatingPoint,
415);
416forward_ops_as_map!(
417    Not not,
418
419    Neg neg where FloatingPoint,
420);
421
422use internals::*;
423mod internals {
424    pub trait Sealed {
425        type Align: Copy + Default;
426    }
427
428    pub trait SimdImpl: super::Vector {
429        fn as_slice(&self) -> &[Self::Element];
430
431        type Array;
432        fn repeat(value: Self::Element) -> Self;
433        fn map(self, f: impl Fn(Self::Element) -> Self::Element) -> Self;
434        fn zip(self, other: Self, f: impl Fn(Self::Element, Self::Element) -> Self::Element) -> Self;
435
436        type Mask: From<bool> + Into<bool>;
437        fn zip_mask(self, other: Self, f: impl Fn(Self::Element, Self::Element) -> Self::Mask) -> Self::MaskVector;
438    }
439
440    macro_rules! define_align_types {
441        ($($t:ident $n:literal)+) => {$(
442            #[repr(align($n))]
443            #[derive(Clone, Copy, Default)]
444            pub struct $t;
445        )+};
446    }
447    define_align_types!(
448        Align8 8
449        Align16 16
450        Align32 32
451        Align64 64
452    );
453
454    pub trait Integer {
455        fn wrapping_add(self, other: Self) -> Self;
456        fn wrapping_sub(self, other: Self) -> Self;
457        fn saturating_add(self, other: Self) -> Self;
458        fn saturating_sub(self, other: Self) -> Self;
459        fn wrapping_mul(self, other: Self) -> Self;
460        fn high_mul(self, other: Self) -> Self;
461        fn count_ones(self) -> Self;
462        fn count_zeros(self) -> Self;
463    }
464    pub trait SignedInteger: Integer {
465        fn wrapping_abs(self) -> Self;
466    }
467    pub trait FloatingPoint {
468        fn recip(self) -> Self;
469        fn to_degrees(self) -> Self;
470        fn to_radians(self) -> Self;
471        fn min_naive(self, other: Self) -> Self;
472        fn max_naive(self, other: Self) -> Self;
473    }
474}
475
476macro_rules! impl_integer {
477    ($($t:ident)+) => {$(
478        impl Integer for $t {
479            #[inline]
480            fn wrapping_add(self, other: Self) -> Self { $t::wrapping_add(self, other) }
481            #[inline]
482            fn wrapping_sub(self, other: Self) -> Self { $t::wrapping_sub(self, other) }
483            #[inline]
484            fn wrapping_mul(self, other: Self) -> Self { $t::wrapping_mul(self, other) }
485            #[inline]
486            fn high_mul(self, other: Self) -> Self { <$t as HighMul>::high_mul(self, other) }
487            #[inline]
488            fn saturating_add(self, other: Self) -> Self { $t::saturating_add(self, other) }
489            #[inline]
490            fn saturating_sub(self, other: Self) -> Self { $t::saturating_sub(self, other) }
491            #[inline]
492            fn count_ones(self) -> Self { $t::count_ones(self) as _ }
493            #[inline]
494            fn count_zeros(self) -> Self { $t::count_zeros(self) as _ }
495        }
496    )+};
497}
498impl_integer!(u8 u16 u32 u64);
499macro_rules! impl_signed_integer {
500    ($($t:ident)+) => {$(
501        impl_integer!($t);
502        impl SignedInteger for $t {
503            #[inline]
504            fn wrapping_abs(self) -> Self { $t::wrapping_abs(self) }
505        }
506    )+};
507}
508impl_signed_integer!(i8 i16 i32 i64);
509macro_rules! impl_floating_point {
510    ($($t:ident)+) => {$(
511        impl FloatingPoint for $t {
512            #[inline]
513            fn recip(self) -> Self { $t::recip(self) }
514            #[inline]
515            fn to_degrees(self) -> Self { $t::to_degrees(self) }
516            #[inline]
517            fn to_radians(self) -> Self { $t::to_radians(self) }
518            #[inline]
519            fn min_naive(self, other: Self) -> Self {
520                // this is not the same as fN::min -- it differs in NAN
521                // handling -- but this way gives the vminp instruction
522                if self < other { self } else { other }
523            }
524            #[inline]
525            fn max_naive(self, other: Self) -> Self {
526                // this is not the same as fN::max -- it differs in NAN
527                // handling -- but this way gives the vmaxp instruction
528                if self > other { self } else { other }
529            }
530        }
531    )+};
532}
533impl_floating_point!(f32 f64);
534
535macro_rules! define_mask_types {
536    ($($t:ident $p:ident)+) => {$(
537        impl From<bool> for $t {
538            #[inline]
539            fn from(b: bool) -> Self {
540                if b { $t::True } else { $t::False }
541            }
542        }
543        impl From<$t> for bool {
544            #[inline]
545            fn from(m: $t) -> bool {
546                match m {
547                    $t::False => false,
548                    $t::True => true,
549                }
550            }
551        }
552        impl From<$t> for $p {
553            #[inline]
554            fn from(m: $t) -> $p {
555                m as $p
556            }
557        }
558        impl Default for $t {
559            #[inline]
560            fn default() -> Self { $t::False }
561        }
562        impl array_utils::Zero for $t {
563            const ZERO: Self = $t::False;
564        }
565        #[repr($p)]
566        #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
567        #[allow(non_camel_case_types)]
568        pub enum $t {
569            False = (0 as $p),
570            True = !(0 as $p),
571        }
572    )+};
573}
574define_mask_types!(
575    m8 u8
576    m16 u16
577    m32 u32
578    m64 u64
579);
580
/// Multiplication that keeps only the upper half of the double-width
/// product, i.e. the bits that `wrapping_mul` discards.
trait HighMul {
    fn high_mul(self, rhs: Self) -> Self;
}
584macro_rules! impl_high_mul {
585    ($($t:ident $t2:ident)+) => {$(
586        impl HighMul for $t {
587            #[inline]
588            fn high_mul(self, other: Self) -> Self {
589                let wide = (self as $t2) * (other as $t2);
590                let high = wide >> (mem::size_of::<$t>() * 8);
591                high as $t
592            }
593        }
594    )+};
595}
596impl_high_mul!(
597    u8 u16
598    u16 u32
599    u32 u64
600    u64 u128
601    i8 i16
602    i16 i32
603    i32 i64
604    i64 i128
605);
606
607mod array_utils;
608
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        let ones = i32x4::splat(1);
        assert_eq!(ones[..], [1, 1, 1, 1]);

        let a = i32x4::from([1, 2, 3, 4]);
        let b = i32x4::from([45, 56, 78, 89]);
        let c = b.wrapping_sub(a);
        assert_eq!(c[..], [44, 54, 75, 85]);
        let d = c.wrapping_add(Simd::splat(10));
        assert_eq!(d[..], [54, 64, 85, 95]);
    }

    #[test]
    fn defaults() {
        // Check the lane contents, not merely that `default()` exists.
        assert!(i8x8::default().iter().all(|&lane| lane == 0));
        assert!(i8x16::default().iter().all(|&lane| lane == 0));
        assert!(i8x32::default().iter().all(|&lane| lane == 0));
        assert!(i8x64::default().iter().all(|&lane| lane == 0));
        assert!(m32x4::default().iter().all(|&lane| lane == m32::False));
    }

    #[test]
    fn mask_comparison() {
        assert!(m16::False < m16::True);
        assert!(m16::False <= m16::True);
    }

    #[test]
    fn comparisons_produce_masks() {
        let a = i32x4::from([1, 2, 3, 4]);
        let three = i32x4::splat(3);
        assert_eq!(a.lt(three)[..], [m32::True, m32::True, m32::False, m32::False]);
        assert_eq!(a.eq(three)[..], [m32::False, m32::False, m32::True, m32::False]);
        assert_eq!(a.ge(three)[..], [m32::False, m32::False, m32::True, m32::True]);

        // Mask lanes widen to all-ones / all-zeros integers.
        let bits = u32x4::from(a.ne(three));
        assert_eq!(bits[..], [u32::MAX, u32::MAX, 0, u32::MAX]);
    }

    #[test]
    fn integer_lane_ops() {
        let x = u8x8::splat(200);
        assert_eq!(x.high_mul(x)[..], [156; 8]);
        assert_eq!(i16x4::splat(-2).high_mul(i16x4::splat(3))[..], [-1; 4]);
        assert_eq!(u8x8::splat(0xF0).count_ones()[..], [4; 8]);

        let v = i32x4::from([1, 5, 3, 7]);
        let four = i32x4::splat(4);
        assert_eq!(v.min(four)[..], [1, 4, 3, 4]);
        assert_eq!(v.max(four)[..], [4, 5, 4, 7]);
    }

    #[test]
    fn float_lane_ops() {
        let v = f32x4::from([1.0, 2.0, 4.0, 8.0]);
        let w = f32x4::splat(3.0);
        assert_eq!(v.recip()[..], [1.0, 0.5, 0.25, 0.125]);
        assert_eq!((v + w)[..], [4.0, 5.0, 7.0, 11.0]);
        assert_eq!(v.min_naive(w)[..], [1.0, 2.0, 3.0, 3.0]);
        assert_eq!(v.max_naive(w)[..], [3.0, 3.0, 4.0, 8.0]);
    }
}